From 4e2df75448d01d28ae6a58906b38371ddad98a9d Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 29 Apr 2026 17:42:35 +0800 Subject: [PATCH 01/26] update --- src/mcore_bridge/config/model_config.py | 5 +++-- src/mcore_bridge/model/constant.py | 1 + src/mcore_bridge/model/mm_gpts/__init__.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/mcore_bridge/config/model_config.py b/src/mcore_bridge/config/model_config.py index 943c100..5eec352 100644 --- a/src/mcore_bridge/config/model_config.py +++ b/src/mcore_bridge/config/model_config.py @@ -3,13 +3,13 @@ import os import re import torch.nn.functional as F -from dataclasses import dataclass +from dataclasses import dataclass, field from megatron.core import mpu from megatron.core.transformer import TransformerConfig from transformers import PretrainedConfig from transformers.utils import is_torch_npu_available from transformers.utils.versions import require_version -from typing import List, Literal, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Union from mcore_bridge.utils import get_logger, json_parse_to_dict @@ -229,6 +229,7 @@ class ModelConfig(TransformerConfig): task_type: Literal['causal_lm', 'seq_cls', 'embedding', 'generative_reranker'] = 'causal_lm' num_labels: Optional[int] = None mlp_padding_free: bool = False + model_kwargs: Dict[str, Any] = field(default_factory=dict) _mindspeed_defaults_cache = None diff --git a/src/mcore_bridge/model/constant.py b/src/mcore_bridge/model/constant.py index 6c61f09..36f9d81 100644 --- a/src/mcore_bridge/model/constant.py +++ b/src/mcore_bridge/model/constant.py @@ -27,6 +27,7 @@ class MLLMModelType: glm4v_moe = 'glm4v_moe' kimi_vl = 'kimi_vl' llama4 = 'llama4' + gemma4 = 'gemma4' kimi_k25 = 'kimi_k25' diff --git a/src/mcore_bridge/model/mm_gpts/__init__.py b/src/mcore_bridge/model/mm_gpts/__init__.py index d13e4e7..b8ea385 100644 --- a/src/mcore_bridge/model/mm_gpts/__init__.py +++ b/src/mcore_bridge/model/mm_gpts/__init__.py @@ -1,2 +1,2 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -from . import glm, internvl, kimi_vl, llama4, llava, qwen, qwen3_5, qwen3_5_gdn, qwen3_omni, qwen3_vl +from . import gemma4, glm, internvl, kimi_vl, llama4, llava, qwen, qwen3_5, qwen3_5_gdn, qwen3_omni, qwen3_vl From c388954d3a86fa9764226fed3b886df0af9497af Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 29 Apr 2026 17:49:36 +0800 Subject: [PATCH 02/26] update --- src/mcore_bridge/model/mm_gpts/gemma4.py | 63 ++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 src/mcore_bridge/model/mm_gpts/gemma4.py diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py new file mode 100644 index 0000000..9b99428 --- /dev/null +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -0,0 +1,63 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. 
+from transformers import AutoModel, PretrainedConfig + +from mcore_bridge.bridge import GPTBridge + +from ..constant import ModelType +from ..register import ModelLoader, ModelMeta, register_model +from .utils import HuggingFaceVit + + +class Gemma4Vit(HuggingFaceVit): + module_mapping = { + 'model.vision_tower': 'vision_tower', + 'model.embed_vision': 'embed_vision', + 'model.audio_tower': 'audio_tower', + 'model.embed_audio': 'embed_audio', + } + _vision_tower = ['vision_tower', 'audio_tower'] + _aligner = ['embed_vision', 'embed_audio'] + support_multimodal = False + + def prepare_model(self, hf_config: PretrainedConfig): + from transformers.models.gemma4.modeling_gemma4 import Gemma4MultimodalEmbedder + self.vision_tower = AutoModel.from_config(hf_config.vision_config) + self.vocab_size = hf_config.text_config.vocab_size + + language_model = AutoModel.from_config(config=hf_config.text_config) + self.language_model = language_model + self.vocab_size_per_layer_input = hf_config.text_config.vocab_size_per_layer_input + self.audio_tower = AutoModel.from_config(hf_config.audio_config) if hf_config.audio_config is not None else None + self.embed_vision = ( + Gemma4MultimodalEmbedder(hf_config.vision_config, hf_config.text_config) + if hf_config.vision_config is not None else None) + self.embed_audio = ( + Gemma4MultimodalEmbedder(hf_config.audio_config, hf_config.text_config) + if hf_config.audio_config is not None else None) + + def get_inputs_embeds(self, inputs_embeds, **kwargs): + return inputs_embeds + + +class Gemma4Bridge(GPTBridge): + pass + + +class Gemma4Loader(ModelLoader): + pass + # def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): + # layer_specs = get_gpt_decoder_block_spec( + # self.config, use_transformer_engine=True, normalization=self.config.normalization, vp_stage=vp_stage) + # for layer_spec in layer_specs.layer_specs: + # pass + # return layer_specs + + +register_model( + ModelMeta( + ModelType.gemma4, + ['gemma4'], + bridge_cls=Gemma4Bridge, + visual_cls=Gemma4Vit, + loader=Gemma4Loader, + )) From 76af2bcebdf2c1011951a219a4bcda87103185ae Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 29 Apr 2026 19:19:39 +0800 Subject: [PATCH 03/26] update --- src/mcore_bridge/model/gpt_model.py | 65 ++++++++++++++++------------- 1 file changed, 37 insertions(+), 28 deletions(-) diff --git a/src/mcore_bridge/model/gpt_model.py b/src/mcore_bridge/model/gpt_model.py index 5fb714c..e34c60a 100644 --- a/src/mcore_bridge/model/gpt_model.py +++ b/src/mcore_bridge/model/gpt_model.py @@ -110,9 +110,7 @@ def __init__( for i in range(len(self.decoder.layers)): if hasattr(self.decoder.layers[i].self_attention, 'rotary_pos_emb'): del self.decoder.layers[i].self_attention.rotary_pos_emb - self.attention_scaling = 1. 
- new_inv_freq, self.attention_scaling = get_rope_inv_freq(config) - self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device) + self._set_inv_freq() if self.config.task_type == 'seq_cls' and self.post_process: self.output_layer = OutputLayerLinear( config.hidden_size, @@ -222,7 +220,36 @@ def _preprocess( if decoder_input is not None and self.training and torch.is_grad_enabled() and not decoder_input.requires_grad: # fix LoRA incompatibility with gradient checkpointing decoder_input = decoder_input.requires_grad_(True) + rotary_pos_emb, rotary_pos_cos, decoder_rotary_pos_emb, rotary_pos_sin = self._get_rotary_pos_emb( + decoder_input, position_ids, packed_seq_params=packed_seq_params) + + if (in_inference_mode and ((self.config.enable_cuda_graph and self.config.cuda_graph_scope != 'full_iteration') + or self.config.flash_decode) and rotary_pos_cos is not None + and inference_context.is_static_batching()): + current_batch_size = input_ids.shape[0] + sequence_len_offset = torch.tensor( + [inference_context.sequence_len_offset] * current_batch_size, + dtype=torch.int32, + device=rotary_pos_cos.device, # Co-locate this with the rotary tensors + ) + else: + sequence_len_offset = None + + # Wrap decoder_input to allow the decoder (TransformerBlock) to delete the + # reference held by this caller function, enabling early garbage collection for + # inference. Skip wrapping if decoder_input is logged after decoder completion. + if in_inference_mode and not has_config_logger_enabled(self.config): + decoder_input = WrappedTensor(decoder_input) + return (decoder_input, rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, + sequence_len_offset) + + def _set_inv_freq(self): + self.attention_scaling = 1. + new_inv_freq, self.attention_scaling = get_rope_inv_freq(self.config) + self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device) + + def _get_rotary_pos_emb(self, decoder_input, position_ids, packed_seq_params, inference_context=None): # Rotary positional embeddings (embedding is None for PP intermediate devices) rotary_pos_emb = None rotary_pos_cos = None @@ -257,26 +284,13 @@ def _preprocess( rotary_seq_len, packed_seq=packed_seq, ) + decoder_rotary_pos_emb = rotary_pos_emb + packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' + if self.position_embedding_type == 'rope' and packed_seq and not self.config.apply_rope_fusion: + assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}' + decoder_rotary_pos_emb = rotary_pos_emb[position_ids[0]] - if (in_inference_mode and ((self.config.enable_cuda_graph and self.config.cuda_graph_scope != 'full_iteration') - or self.config.flash_decode) and rotary_pos_cos is not None - and inference_context.is_static_batching()): - current_batch_size = input_ids.shape[0] - sequence_len_offset = torch.tensor( - [inference_context.sequence_len_offset] * current_batch_size, - dtype=torch.int32, - device=rotary_pos_cos.device, # Co-locate this with the rotary tensors - ) - else: - sequence_len_offset = None - - # Wrap decoder_input to allow the decoder (TransformerBlock) to delete the - # reference held by this caller function, enabling early garbage collection for - # inference. Skip wrapping if decoder_input is logged after decoder completion. 
- if in_inference_mode and not has_config_logger_enabled(self.config): - decoder_input = WrappedTensor(decoder_input) - - return decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset + return rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin # Code borrowed from NVIDIA/Megatron-LM def forward( @@ -308,7 +322,7 @@ def forward( inference_context = deprecate_inference_params(inference_context, inference_params) - decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset = ( + decoder_input, rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset = ( self._preprocess( input_ids=input_ids, position_ids=position_ids, @@ -316,11 +330,6 @@ def forward( inference_context=inference_context, packed_seq_params=packed_seq_params, )) - decoder_rotary_pos_emb = rotary_pos_emb - packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' - if self.position_embedding_type == 'rope' and packed_seq and not self.config.apply_rope_fusion: - assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}' - decoder_rotary_pos_emb = rotary_pos_emb[position_ids[0]] mtp_decoder_input = decoder_input if self.config.is_multimodal and self.config.mtp_num_layers and decoder_input is None: From 54e33435585c2bdd1f5045c162b238e98f0565ba Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 29 Apr 2026 21:47:11 +0800 Subject: [PATCH 04/26] update --- src/mcore_bridge/tuners/patcher.py | 39 +++--------------------------- src/mcore_bridge/utils/__init__.py | 2 +- src/mcore_bridge/utils/utils.py | 36 +++++++++++++++++++++++++++ 3 files changed, 40 insertions(+), 37 deletions(-) diff --git a/src/mcore_bridge/tuners/patcher.py b/src/mcore_bridge/tuners/patcher.py index e715c35..f9cae8c 100644 --- a/src/mcore_bridge/tuners/patcher.py +++ b/src/mcore_bridge/tuners/patcher.py @@ -1,6 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
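# Usage sketch for the deepcopy guard that the hunks below move into mcore_bridge.utils
# (illustrative only; `megatron_module` is a placeholder for any MegatronModule whose
# `tp_group`/`_tp_group`/`config` attributes should not be deep-copied):
import copy

from mcore_bridge.utils import patch_deepcopy

with patch_deepcopy():
    # the listed attributes are detached before the copy and re-attached to both the
    # original and the copy afterwards, so copy.deepcopy never touches them
    module_copy = copy.deepcopy(megatron_module)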
-import copy -from contextlib import contextmanager from megatron.core.extensions.transformer_engine import TEGroupedLinear, TELayerNormColumnParallelLinear, TELinear from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.moe.router import TopKRouter @@ -11,6 +9,8 @@ from torch import nn from typing import Optional +from mcore_bridge.utils import patch_deepcopy + from .lora import LoraParallelLinear @@ -37,39 +37,6 @@ def dispatch_megatron( model.dispatch_megatron = dispatch_megatron -@contextmanager -def _patch_deepcopy(): - _origin_deepcopy = copy.deepcopy - copy_keys = ('tp_group', '_tp_group', 'config') - - def new_deepcopy(x, *args, **kwargs): - if not isinstance(x, nn.Module): - return _origin_deepcopy(x, *args, **kwargs) - - saved_values = {} - for key in copy_keys: - val = getattr(x, key, None) - if val is not None: - saved_values[key] = val - setattr(x, key, None) - - try: - res = _origin_deepcopy(x, *args, **kwargs) - finally: - for key, value in saved_values.items(): - setattr(x, key, value) - - for key, value in saved_values.items(): - setattr(res, key, value) - return res - - copy.deepcopy = new_deepcopy - try: - yield - finally: - copy.deepcopy = _origin_deepcopy - - def _patch_lora_model(): if hasattr(LoraModel, '_mcore_patched'): return @@ -77,7 +44,7 @@ def _patch_lora_model(): __origin_init__ = LoraModel.__init__ def __new_init__(self, *args, **kwargs): - with _patch_deepcopy(): + with patch_deepcopy(): __origin_init__(self, *args, **kwargs) if not isinstance(self.model, MegatronModule): return diff --git a/src/mcore_bridge/utils/__init__.py b/src/mcore_bridge/utils/__init__.py index 34cdbd1..d4285be 100644 --- a/src/mcore_bridge/utils/__init__.py +++ b/src/mcore_bridge/utils/__init__.py @@ -6,4 +6,4 @@ from .megatron_utils import get_local_layer_specs, set_random_seed, split_cp_inputs, unwrap_model from .safetensors import SafetensorLazyLoader, StreamingSafetensorSaver from .torch_utils import gc_collect, get_current_device, safe_ddp_context, to_device -from .utils import deep_getattr, get_env_args, json_parse_to_dict +from .utils import deep_getattr, get_env_args, json_parse_to_dict, patch_deepcopy diff --git a/src/mcore_bridge/utils/utils.py b/src/mcore_bridge/utils/utils.py index 6905c41..a7e525b 100644 --- a/src/mcore_bridge/utils/utils.py +++ b/src/mcore_bridge/utils/utils.py @@ -1,5 +1,8 @@ +import copy import json import os +from contextlib import contextmanager +from torch import nn from transformers.utils import strtobool from typing import Callable, Dict, Optional, TypeVar, Union @@ -58,3 +61,36 @@ def deep_getattr(obj, attr: str, default=None): else: obj = getattr(obj, a, default) return obj + + +@contextmanager +def patch_deepcopy(): + _origin_deepcopy = copy.deepcopy + copy_keys = ('tp_group', '_tp_group', 'config') + + def new_deepcopy(x, *args, **kwargs): + if not isinstance(x, nn.Module): + return _origin_deepcopy(x, *args, **kwargs) + + saved_values = {} + for key in copy_keys: + val = getattr(x, key, None) + if val is not None: + saved_values[key] = val + setattr(x, key, None) + + try: + res = _origin_deepcopy(x, *args, **kwargs) + finally: + for key, value in saved_values.items(): + setattr(x, key, value) + + for key, value in saved_values.items(): + setattr(res, key, value) + return res + + copy.deepcopy = new_deepcopy + try: + yield + finally: + copy.deepcopy = _origin_deepcopy From 25a45bda3aa7dde0b9125c33614f94bdc0c4c18e Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 30 Apr 2026 10:13:18 +0800 Subject: 
[PATCH 05/26] update --- src/mcore_bridge/model/gpt_model.py | 2 +- src/mcore_bridge/model/mm_gpt_model.py | 4 ++- src/mcore_bridge/model/mm_gpts/gemma4.py | 41 ++++++++++++++++++------ src/mcore_bridge/model/rope.py | 4 +-- 4 files changed, 38 insertions(+), 13 deletions(-) diff --git a/src/mcore_bridge/model/gpt_model.py b/src/mcore_bridge/model/gpt_model.py index e34c60a..ace31e4 100644 --- a/src/mcore_bridge/model/gpt_model.py +++ b/src/mcore_bridge/model/gpt_model.py @@ -220,7 +220,7 @@ def _preprocess( if decoder_input is not None and self.training and torch.is_grad_enabled() and not decoder_input.requires_grad: # fix LoRA incompatibility with gradient checkpointing decoder_input = decoder_input.requires_grad_(True) - rotary_pos_emb, rotary_pos_cos, decoder_rotary_pos_emb, rotary_pos_sin = self._get_rotary_pos_emb( + rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin = self._get_rotary_pos_emb( decoder_input, position_ids, packed_seq_params=packed_seq_params) if (in_inference_mode and ((self.config.enable_cuda_graph and self.config.cuda_graph_scope != 'full_iteration') diff --git a/src/mcore_bridge/model/mm_gpt_model.py b/src/mcore_bridge/model/mm_gpt_model.py index b68e82b..b3fc0d6 100644 --- a/src/mcore_bridge/model/mm_gpt_model.py +++ b/src/mcore_bridge/model/mm_gpt_model.py @@ -18,6 +18,7 @@ class MultimodalGPTModel(MegatronModule): + language_model_cls = GPTModel def __init__(self, config: ModelConfig, @@ -29,7 +30,8 @@ def __init__(self, super().__init__(config) self.pre_process = pre_process self.post_process = post_process - self.language_model = GPTModel(config, transformer_layer_spec, pre_process, post_process, *_args, **kwargs) + self.language_model = self.language_model_cls(config, transformer_layer_spec, pre_process, post_process, *_args, + **kwargs) self.vp_stage = self.language_model.vp_stage self.share_embeddings_and_output_weights = self.language_model.share_embeddings_and_output_weights self.model_meta = config.model_meta diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index 9b99428..2d656dd 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -1,10 +1,14 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
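# Sketch of the dual-RoPE setup the gemma4 hunk below implements (illustrative; the
# per-layer-type rope_scaling dict, the 'proportional' rope type and global_head_dim
# come from the hf text_config this series targets): sliding-window layers and
# full-attention layers derive separate inv_freq tensors from one config by temporarily
# swapping config.rope_scaling, roughly:
from mcore_bridge.model.rope import get_rope_inv_freq


def build_dual_inv_freq(config):
    rope_scaling = config.rope_scaling  # {'sliding_attention': {...}, 'full_attention': {...}}
    try:
        config.rope_scaling = rope_scaling['sliding_attention']
        sliding_inv_freq, _ = get_rope_inv_freq(config)
        config.rope_scaling = rope_scaling['full_attention']
        kwargs = {}
        if config.rope_scaling['rope_type'] == 'proportional':
            kwargs['head_dim_key'] = 'global_head_dim'  # global layers may use a wider head dim
        full_inv_freq, _ = get_rope_inv_freq(config, **kwargs)
    finally:
        config.rope_scaling = rope_scaling
    return sliding_inv_freq, full_inv_freq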
+import copy from transformers import AutoModel, PretrainedConfig from mcore_bridge.bridge import GPTBridge from ..constant import ModelType +from ..gpt_model import GPTModel +from ..mm_gpt_model import MultimodalGPTModel from ..register import ModelLoader, ModelMeta, register_model +from ..rope import get_rope_inv_freq from .utils import HuggingFaceVit @@ -22,15 +26,8 @@ class Gemma4Vit(HuggingFaceVit): def prepare_model(self, hf_config: PretrainedConfig): from transformers.models.gemma4.modeling_gemma4 import Gemma4MultimodalEmbedder self.vision_tower = AutoModel.from_config(hf_config.vision_config) - self.vocab_size = hf_config.text_config.vocab_size - - language_model = AutoModel.from_config(config=hf_config.text_config) - self.language_model = language_model - self.vocab_size_per_layer_input = hf_config.text_config.vocab_size_per_layer_input self.audio_tower = AutoModel.from_config(hf_config.audio_config) if hf_config.audio_config is not None else None - self.embed_vision = ( - Gemma4MultimodalEmbedder(hf_config.vision_config, hf_config.text_config) - if hf_config.vision_config is not None else None) + self.embed_vision = Gemma4MultimodalEmbedder(hf_config.vision_config, hf_config.text_config) self.embed_audio = ( Gemma4MultimodalEmbedder(hf_config.audio_config, hf_config.text_config) if hf_config.audio_config is not None else None) @@ -43,8 +40,34 @@ class Gemma4Bridge(GPTBridge): pass +class Gemma4TextGPTModel(GPTModel): + + def _set_inv_freq(self): + rope_scaling = self.config.rope_scaling + self.config.rope_scaling = rope_scaling['sliding_attention'] + new_inv_freq, attention_scaling = get_rope_inv_freq(self.config) + assert attention_scaling == 1, 'not support' + self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device) + # full + self.full_rotary_pos_emb = copy.copy(self.rotary_pos_emb) + self.config.rope_scaling = rope_scaling['full_attention'] + kwargs = {} + if self.config.rope_scaling['rope_type'] == 'proportional': + kwargs['head_dim_key'] = 'global_head_dim' + new_inv_freq, attention_scaling = get_rope_inv_freq(self.config, **kwargs) + assert attention_scaling == 1, 'not support' + self.full_rotary_pos_emb.inv_freq = new_inv_freq + self.attention_scaling = attention_scaling + + self.config.rope_scaling = rope_scaling + + +class Gemma4GPTModel(MultimodalGPTModel): + language_model_cls = Gemma4TextGPTModel + + class Gemma4Loader(ModelLoader): - pass + model_cls = Gemma4GPTModel # def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): # layer_specs = get_gpt_decoder_block_spec( # self.config, use_transformer_engine=True, normalization=self.config.normalization, vp_stage=vp_stage) diff --git a/src/mcore_bridge/model/rope.py b/src/mcore_bridge/model/rope.py index 5cabe42..e7db3c3 100644 --- a/src/mcore_bridge/model/rope.py +++ b/src/mcore_bridge/model/rope.py @@ -106,12 +106,12 @@ def _get_rope_type(rope_scaling: Optional[Dict[str, Any]]): return rope_type -def get_rope_inv_freq(config, seq_len=None): +def get_rope_inv_freq(config, seq_len=None, **kwargs): from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS ROPE_INIT_FUNCTIONS.update(EXTENDED_ROPE_INIT_FUNCTIONS) dummy_config = _get_dummy_config(config) rope_init_fn = ROPE_INIT_FUNCTIONS[_get_rope_type(config.rope_scaling)] - inv_freq, attention_scaling = rope_init_fn(dummy_config, 'cpu', seq_len=seq_len) + inv_freq, attention_scaling = rope_init_fn(dummy_config, 'cpu', seq_len=seq_len, **kwargs) if attention_scaling is None: attention_scaling = 1. 
return inv_freq, attention_scaling From 32106178dc2476f09e2589262b8c5b3a28b70dce Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 30 Apr 2026 11:30:56 +0800 Subject: [PATCH 06/26] update --- src/mcore_bridge/model/mm_gpts/gemma4.py | 39 +++++++++++++++++++----- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index 2d656dd..b2a7fde 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -1,8 +1,12 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import copy +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec +from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules from transformers import AutoModel, PretrainedConfig +from typing import Optional -from mcore_bridge.bridge import GPTBridge +from mcore_bridge.bridge import MultimodalGPTBridge +from mcore_bridge.config import ModelConfig from ..constant import ModelType from ..gpt_model import GPTModel @@ -36,12 +40,30 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs): return inputs_embeds -class Gemma4Bridge(GPTBridge): +class Gemma4SelfAttention(SelfAttention): + + def __init__( + self, + config: ModelConfig, + submodules: SelfAttentionSubmodules, + layer_number: int, + *args, + **kwargs, + ): + text_config = config.hf_config.text_config + super().__init__(config, submodules, layer_number, *args, **kwargs) + + +class Gemma4Bridge(MultimodalGPTBridge): pass class Gemma4TextGPTModel(GPTModel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + print() + def _set_inv_freq(self): rope_scaling = self.config.rope_scaling self.config.rope_scaling = rope_scaling['sliding_attention'] @@ -68,12 +90,13 @@ class Gemma4GPTModel(MultimodalGPTModel): class Gemma4Loader(ModelLoader): model_cls = Gemma4GPTModel - # def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): - # layer_specs = get_gpt_decoder_block_spec( - # self.config, use_transformer_engine=True, normalization=self.config.normalization, vp_stage=vp_stage) - # for layer_spec in layer_specs.layer_specs: - # pass - # return layer_specs + + def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): + layer_specs = get_gpt_decoder_block_spec( + self.config, use_transformer_engine=True, normalization=self.config.normalization, vp_stage=vp_stage) + for layer_spec in layer_specs.layer_specs: + layer_spec.submodules.self_attention.module = Gemma4SelfAttention + return layer_specs register_model( From 5b4e118bc804475f58a057a7805f7ff6f312ca4d Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 4 May 2026 18:15:41 +0800 Subject: [PATCH 07/26] update --- src/mcore_bridge/model/modules/__init__.py | 1 + .../model/modules/transformer_layer.py | 30 +++++++++++++++++++ src/mcore_bridge/patcher.py | 30 ------------------- 3 files changed, 31 insertions(+), 30 deletions(-) create mode 100644 src/mcore_bridge/model/modules/transformer_layer.py diff --git a/src/mcore_bridge/model/modules/__init__.py b/src/mcore_bridge/model/modules/__init__.py index 6fd1ac7..eff1bd6 100644 --- a/src/mcore_bridge/model/modules/__init__.py +++ b/src/mcore_bridge/model/modules/__init__.py @@ -2,3 +2,4 @@ from .gated_delta_net import GatedDeltaNet from .gated_self_attention import GatedSelfAttention from .mtp_layer import MultiTokenPredictionLayer +from .transformer_layer import CustomTransformerLayer diff --git 
a/src/mcore_bridge/model/modules/transformer_layer.py b/src/mcore_bridge/model/modules/transformer_layer.py new file mode 100644 index 0000000..55aa952 --- /dev/null +++ b/src/mcore_bridge/model/modules/transformer_layer.py @@ -0,0 +1,30 @@ +import megatron.core +from megatron.core.transformer import TransformerLayer +from packaging import version + +mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') + + +class CustomTransformerLayer(TransformerLayer): + + def forward(self, *args, **kwargs): + """ + Perform a forward pass through the transformer layer. + + This method calls the core computation of a transformer layer, including + self-attention, cross-attention (if applicable), and feed-forward operations. + """ + if not mcore_013: + return super().forward(self, *args, **kwargs) + hidden_states, context = self._forward_attention(*args, **kwargs) + mlp_padding_free = self.config.mlp_padding_free and 'attention_mask' in kwargs + mask = None + if mlp_padding_free and hidden_states.shape[1] > 1: + mask = ((~kwargs['attention_mask']).sum(dim=(1, 2)) > 0).t() + hidden_states = hidden_states[mask][:, None] + output = self._forward_mlp(hidden_states, kwargs.get('inference_context', None)) + if mask is not None: + new_output = hidden_states.new_zeros((*mask.shape, output.shape[-1])) + new_output[mask] = output.squeeze(1) + output = new_output + return output, context diff --git a/src/mcore_bridge/patcher.py b/src/mcore_bridge/patcher.py index 15280b5..527b0c5 100644 --- a/src/mcore_bridge/patcher.py +++ b/src/mcore_bridge/patcher.py @@ -13,7 +13,6 @@ from megatron.core.tensor_parallel.mappings import (gather_from_sequence_parallel_region, gather_from_tensor_model_parallel_region, scatter_to_sequence_parallel_region) -from megatron.core.transformer import TransformerLayer from megatron.core.transformer.multi_latent_attention import MLASelfAttention, MultiLatentAttention from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionBlock, get_mtp_layer_offset from megatron.core.utils import deprecate_inference_params @@ -413,34 +412,6 @@ def sharded_state_dict( peft_module.OriginModulesToSaveWrapper = OriginModulesToSaveWrapper -def _patch_TransformerLayer(): - _origin_forward = TransformerLayer.forward - - def forward(self, *_args, **kwargs): - """ - Perform a forward pass through the transformer layer. - - This method calls the core computation of a transformer layer, including - self-attention, cross-attention (if applicable), and feed-forward operations. 
- """ - if not mcore_013: - return _origin_forward(self, *_args, **kwargs) - hidden_states, context = self._forward_attention(*_args, **kwargs) - mlp_padding_free = self.config.mlp_padding_free and 'attention_mask' in kwargs - mask = None - if mlp_padding_free and hidden_states.shape[1] > 1: - mask = ((~kwargs['attention_mask']).sum(dim=(1, 2)) > 0).t() - hidden_states = hidden_states[mask][:, None] - output = self._forward_mlp(hidden_states, kwargs.get('inference_context', None)) - if mask is not None: - new_output = hidden_states.new_zeros((*mask.shape, output.shape[-1])) - new_output[mask] = output.squeeze(1) - output = new_output - return output, context - - TransformerLayer.forward = forward - - def _patch_TELinear(): def __repr__(self): @@ -769,7 +740,6 @@ def apply_patch(): # patch module _patch_mla_attention() _patch_TEGroupedLinear() - _patch_TransformerLayer() _patch_TELinear() _patch_mrope() _patch_mtp() From d1d22462c3e56e89ac34f06aafeb484243cfd78a Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 5 May 2026 00:07:45 +0800 Subject: [PATCH 08/26] update --- src/mcore_bridge/model/mm_gpts/gemma4.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index b2a7fde..5b09745 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -4,7 +4,7 @@ from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules from transformers import AutoModel, PretrainedConfig from typing import Optional - +from megatron.core.transformer.mlp import MLP from mcore_bridge.bridge import MultimodalGPTBridge from mcore_bridge.config import ModelConfig @@ -51,7 +51,27 @@ def __init__( **kwargs, ): text_config = config.hf_config.text_config + self.is_sliding = text_config.layer_types[layer_number - 1] == 'sliding_attention' + self.sliding_window = text_config.sliding_window if self.is_sliding else None + kv_channels = config.kv_channels + config.kv_channels = text_config.global_head_dim if not self.is_sliding and text_config.global_head_dim else text_config.head_dim super().__init__(config, submodules, layer_number, *args, **kwargs) + config.kv_channels = kv_channels + +class Gemma4MLP(MLP): + def __init__( + self, + config: ModelConfig, + submodules: SelfAttentionSubmodules, + *args, + **kwargs, + ): + text_config = config.hf_config.text_config + self.enable_moe_block = text_config.enable_moe_block + first_kv_shared_layer_idx = text_config.num_hidden_layers - text_config.num_kv_shared_layers + is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0 + use_double_wide_mlp = config.use_double_wide_mlp and is_kv_shared_layer + super().__init__(config, submodules, *args, **kwargs) class Gemma4Bridge(MultimodalGPTBridge): @@ -96,6 +116,7 @@ def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): self.config, use_transformer_engine=True, normalization=self.config.normalization, vp_stage=vp_stage) for layer_spec in layer_specs.layer_specs: layer_spec.submodules.self_attention.module = Gemma4SelfAttention + layer_spec.submodules.mlp.module = Gemma4MLP return layer_specs From 14b164489b9228b73276b3f6b3560984e483baac Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 5 May 2026 12:19:34 +0800 Subject: [PATCH 09/26] fix --- src/mcore_bridge/model/mm_gpts/gemma4.py | 11 +- .../model/modules/transformer_layer.py | 208 +++++++++++++++++- src/mcore_bridge/model/register.py | 6 +- 3 files changed, 221 
insertions(+), 4 deletions(-) diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index 5b09745..8cba526 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -2,9 +2,10 @@ import copy from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules +from megatron.core.transformer.mlp import MLP from transformers import AutoModel, PretrainedConfig from typing import Optional -from megatron.core.transformer.mlp import MLP + from mcore_bridge.bridge import MultimodalGPTBridge from mcore_bridge.config import ModelConfig @@ -54,18 +55,24 @@ def __init__( self.is_sliding = text_config.layer_types[layer_number - 1] == 'sliding_attention' self.sliding_window = text_config.sliding_window if self.is_sliding else None kv_channels = config.kv_channels - config.kv_channels = text_config.global_head_dim if not self.is_sliding and text_config.global_head_dim else text_config.head_dim + config.kv_channels = ( + text_config.global_head_dim if not self.is_sliding and text_config.global_head_dim else text_config.head_dim + ) super().__init__(config, submodules, layer_number, *args, **kwargs) config.kv_channels = kv_channels + class Gemma4MLP(MLP): + def __init__( self, config: ModelConfig, submodules: SelfAttentionSubmodules, + layer_number: int, *args, **kwargs, ): + self.layer_number = layer_number text_config = config.hf_config.text_config self.enable_moe_block = text_config.enable_moe_block first_kv_shared_layer_idx = text_config.num_hidden_layers - text_config.num_kv_shared_layers diff --git a/src/mcore_bridge/model/modules/transformer_layer.py b/src/mcore_bridge/model/modules/transformer_layer.py index 55aa952..83c42e3 100644 --- a/src/mcore_bridge/model/modules/transformer_layer.py +++ b/src/mcore_bridge/model/modules/transformer_layer.py @@ -1,12 +1,218 @@ import megatron.core -from megatron.core.transformer import TransformerLayer +import torch +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.transformer.enums import CudaGraphScope, LayerType +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.mlp import MLP +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.transformer_layer import (TransformerLayer, TransformerLayerSubmodules, + get_transformer_layer_offset) +from megatron.core.utils import get_pg_rank from packaging import version +from typing import Optional + +from mcore_bridge.utils import get_logger mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') +logger = get_logger() + class CustomTransformerLayer(TransformerLayer): + def __init__( + self, + config: TransformerConfig, + submodules: TransformerLayerSubmodules, + layer_number: int = 1, + hidden_dropout: Optional[float] = None, + pg_collection: Optional[ProcessGroupCollection] = None, + vp_stage: Optional[int] = None, + is_mtp_layer: bool = False, + add_layer_offset: bool = True, + pp_layer_offset: Optional[int] = None, + ): + self.submodules_config = submodules + super().__init__(config=config, vp_stage=vp_stage) + + if pg_collection is None: + pg_collection = ProcessGroupCollection.use_mpu_process_groups() + self.pg_collection = pg_collection + self.tp_group = pg_collection.tp + + # MTP inner layers 
use their own layer numbering (starting from 1 within each MTP depth), + # so they should NOT add the decoder layer offset. The router.py handles MTP layer + # numbering separately by adding config.num_layers to distinguish MTP layers from decoder + # layers in the aux loss tracker. + # + # When add_layer_offset is False, the caller has already included the correct offset + # in layer_number (e.g. when using --hybrid-layer-pattern with fVPP). + if is_mtp_layer or not add_layer_offset: + self.layer_number = layer_number + else: + self.layer_number = layer_number + get_transformer_layer_offset(self.config, vp_stage, + get_pg_rank(pg_collection.pp)) + self.hidden_dropout = config.hidden_dropout if hidden_dropout is None else hidden_dropout + self.is_mtp_layer = is_mtp_layer + + # [Module 1: Input Layernorm] Optional Layernorm on the input data + # TODO: add pytorch only layernorm + self.input_layernorm = submodules.input_layernorm( + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + + attention_optional_kwargs = {} + if config.context_parallel_size > 1 and config.cp_comm_type is not None: + if isinstance(config.cp_comm_type, list): + # layer_number is 1-indexed, so we need to subtract 1 to get the correct index + attention_optional_kwargs['cp_comm_type'] = config.cp_comm_type[self.layer_number - 1] + else: + attention_optional_kwargs['cp_comm_type'] = config.cp_comm_type + + attention_optional_kwargs['pg_collection'] = pg_collection + if pp_layer_offset is not None: + attention_optional_kwargs['pp_layer_offset'] = pp_layer_offset + + # [Module 2: SelfAttention] + self.self_attention = build_module( + submodules.self_attention, + config=self.config, + layer_number=self.layer_number, + **attention_optional_kwargs, + ) + + # [Module 3: BiasDropoutFusion] + self.self_attn_bda = build_module(submodules.self_attn_bda) + + # [Module 4: Post SelfAttention] Optional Layernorm after self-attn + self.pre_cross_attn_layernorm = submodules.pre_cross_attn_layernorm( + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + + # [Module 5: CrossAttention] + self.cross_attention = build_module( + submodules.cross_attention, + config=self.config, + layer_number=self.layer_number, + **attention_optional_kwargs, + ) + + # [Module 6: BiasDropoutFusion] + self.cross_attn_bda = build_module(submodules.cross_attn_bda, config=self.config) + + # [Module 7: Pre MLP] Optional Layernorm before MLP + self.pre_mlp_layernorm = submodules.pre_mlp_layernorm( + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + # [Module 8: MLP block] + additional_mlp_kwargs = {} + # import here to avoid circular import + from megatron.core.extensions.transformer_engine import TEFusedMLP + from megatron.core.transformer.moe.experts import SequentialMLP, TEGroupedMLP + from megatron.core.transformer.moe.moe_layer import MoELayer + + # MLP expects tp_group but MoELayer expects pg_collection to be passed in. + # We can change MLP to accept pg_collection but it makes the logic implicit + # The conditional below is to make the logic explicit + # if submodules.mlp is not a ModuleSpec,we dont have to handle passing additional kwargs + if isinstance(submodules.mlp, ModuleSpec): + if submodules.mlp.module in (MoELayer, TEGroupedMLP, SequentialMLP): + additional_mlp_kwargs['pg_collection'] = pg_collection + # Pass is_mtp_layer flag to MoELayer to distinguish MTP MoE layers. 
+ if submodules.mlp.module == MoELayer: + additional_mlp_kwargs['is_mtp_layer'] = self.is_mtp_layer + elif submodules.mlp.module == MLP: + assert hasattr(pg_collection, 'tp'), 'TP process group is required for MLP in TransformerLayer' + additional_mlp_kwargs['tp_group'] = pg_collection.tp + elif TEFusedMLP is not None and submodules.mlp.module == TEFusedMLP: + assert hasattr(pg_collection, 'tp'), 'TP process group is required for TEFusedMLP in TransformerLayer' + additional_mlp_kwargs['tp_group'] = pg_collection.tp + else: + logger.warning_once(f"Unknown MLP type: {type(submodules.mlp)}. Using default kwargs.") + self.mlp = build_module(submodules.mlp, config=self.config, **additional_mlp_kwargs) + if hasattr(self.mlp, 'set_layer_number'): + self.mlp.set_layer_number(self.layer_number) + + # [Module 9: BiasDropoutFusion] + self.mlp_bda = build_module(submodules.mlp_bda) + + self.is_moe_layer = isinstance(self.mlp, MoELayer) + + self.recompute_input_layernorm = False + self.recompute_pre_mlp_layernorm = False + self.recompute_mlp = False + if self.config.recompute_granularity == 'selective': + assert self.config.recompute_modules is not None + if 'layernorm' in self.config.recompute_modules: + if not isinstance(self.input_layernorm, IdentityOp): + self.recompute_input_layernorm = True + if self.config.fp8 or self.config.fp4: + self.self_attention.set_for_recompute_input_layernorm() + + def can_recompute_pre_mlp_layernorm_for_cudagraph(): + if (not self.is_moe_layer or CudaGraphScope.moe_router not in self.config.cuda_graph_scope + or self.config.cuda_graph_impl == 'local'): + # Not a MoE layer, or not capturing the router part. + return True + if (self.config.moe_shared_expert_intermediate_size is not None + and self.config.moe_shared_expert_overlap): + # If shared expert overlap is used, we cannot make the pre-mlp layernorm + # recomputation, because the shared expert takes the layernorm output as + # input, and it is outside of the CUDA graph scope. + logger.warning( + 'pre_mlp_layernorm recompute is not supported with moe router ' + 'cudagraph + shared expert overlap. Disabling pre_mlp_layernorm ' + 'recompute.', ) + return False + if CudaGraphScope.moe_preprocess in self.config.cuda_graph_scope and ( + self.config.moe_token_dispatcher_type == 'alltoall' or self.config.moe_latent_size): + # Only when capturing the preprocess part and using alltoall token + # dispatcher or latent MoE can we make the pre-mlp layernorm recomputation. + # Because in other cases the layernorm output returns directly as one of the + # outputs of the cudagraph, which will be allocated a static buffer, thus + # not able to be released. + return True + logger.warning( + 'pre_mlp_layernorm recompute is only supported with moe router + ' + 'preprocess cudagraph will alltoall token dispatcher or latent MoE. 
' + 'Disabling pre_mlp_layernorm recompute.', ) + return False + + if (not isinstance(self.pre_mlp_layernorm, IdentityOp) + and can_recompute_pre_mlp_layernorm_for_cudagraph()): + self.recompute_pre_mlp_layernorm = True + if self.config.fp8 or self.config.fp4: + if isinstance(self.mlp, MoELayer): + self.mlp.set_for_recompute_pre_mlp_layernorm() + else: + from megatron.core.extensions.transformer_engine import set_save_original_input + + set_save_original_input(self.mlp.linear_fc1) + if 'mlp' in self.config.recompute_modules: + if not self.is_moe_layer: + self.recompute_mlp = True + self.offload_attn_norm = ( + self.config.fine_grained_activation_offloading and 'attn_norm' in self.config.offload_modules + and not isinstance(self.input_layernorm, IdentityOp)) + self.offload_mlp_norm = ( + self.config.fine_grained_activation_offloading and 'mlp_norm' in self.config.offload_modules + and not isinstance(self.pre_mlp_layernorm, IdentityOp)) + + # @jcasper how should we handle nvfuser? + # Set bias+dropout+add fusion grad_enable execution handler. + # TORCH_MAJOR = int(torch.__version__.split('.')[0]) + # TORCH_MINOR = int(torch.__version__.split('.')[1]) + # use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10) + # self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad + self.bias_dropout_add_exec_handler = torch.enable_grad + def forward(self, *args, **kwargs): """ Perform a forward pass through the transformer layer. diff --git a/src/mcore_bridge/model/register.py b/src/mcore_bridge/model/register.py index 0e67e90..9be5b6f 100644 --- a/src/mcore_bridge/model/register.py +++ b/src/mcore_bridge/model/register.py @@ -15,7 +15,7 @@ from mcore_bridge.config import ModelConfig from mcore_bridge.utils import get_logger -from .modules import MultiTokenPredictionLayer +from .modules import CustomTransformerLayer, MultiTokenPredictionLayer if TYPE_CHECKING: from .gpt_model import GPTModel @@ -138,6 +138,10 @@ def _set_shared_expert_gate(self, transformer_layer_spec): if hasattr(layer_spec.submodules.mlp.submodules, 'shared_experts'): layer_spec.submodules.mlp.submodules.shared_experts.params = {'gate': True} + def _set_custom_layer(self, transformer_layer_spec): + pass + # CustomTransformerLayer + def build_model( self, pre_process=True, From 196a58fda42452eebe2eabee2f2545f0099e2122 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 5 May 2026 13:19:36 +0800 Subject: [PATCH 10/26] update --- src/mcore_bridge/model/gpts/glm4.py | 10 ++-- src/mcore_bridge/model/gpts/minimax_m2.py | 7 +-- src/mcore_bridge/model/mm_gpts/gemma4.py | 4 +- .../model/modules/transformer_layer.py | 45 +++++++++++++----- src/mcore_bridge/model/register.py | 46 ++++++------------- 5 files changed, 58 insertions(+), 54 deletions(-) diff --git a/src/mcore_bridge/model/gpts/glm4.py b/src/mcore_bridge/model/gpts/glm4.py index 5c7a8ca..861f94d 100644 --- a/src/mcore_bridge/model/gpts/glm4.py +++ b/src/mcore_bridge/model/gpts/glm4.py @@ -91,11 +91,11 @@ def _set_layer_attn(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: boo class Glm4Loader(ModelLoader): def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): - layer_spec = self._get_transformer_layer_spec() - layer_spec.submodules.self_attention.module = Glm4SelfAttention - layer_spec.submodules.mlp.module = Glm4MLP - transformer_layer.MLP = Glm4MLP # patch - return layer_spec + transformer_layer_spec = super().get_transformer_layer_spec(vp_stage) + for layer_spec in transformer_layer_spec.layer_specs: + 
layer_spec.submodules.self_attention.module = Glm4SelfAttention + layer_spec.submodules.mlp.module = Glm4MLP + return transformer_layer_spec register_model(ModelMeta( diff --git a/src/mcore_bridge/model/gpts/minimax_m2.py b/src/mcore_bridge/model/gpts/minimax_m2.py index 81b11b6..c03f803 100644 --- a/src/mcore_bridge/model/gpts/minimax_m2.py +++ b/src/mcore_bridge/model/gpts/minimax_m2.py @@ -95,9 +95,10 @@ def _set_moe_state( class MinimaxM2Loader(ModelLoader): def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): - layer_spec = self._get_transformer_layer_spec() - layer_spec.submodules.self_attention.module = MinimaxM2SelfAttention - return layer_spec + transformer_layer_spec = super().get_transformer_layer_spec(vp_stage) + for layer_spec in transformer_layer_spec.layer_specs: + layer_spec.submodules.self_attention.module = MinimaxM2SelfAttention + return transformer_layer_spec register_model(ModelMeta( diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index 8cba526..8bfc0c2 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -56,8 +56,8 @@ def __init__( self.sliding_window = text_config.sliding_window if self.is_sliding else None kv_channels = config.kv_channels config.kv_channels = ( - text_config.global_head_dim if not self.is_sliding and text_config.global_head_dim else text_config.head_dim - ) + text_config.global_head_dim + if not self.is_sliding and text_config.global_head_dim else text_config.head_dim) super().__init__(config, submodules, layer_number, *args, **kwargs) config.kv_channels = kv_channels diff --git a/src/mcore_bridge/model/modules/transformer_layer.py b/src/mcore_bridge/model/modules/transformer_layer.py index 83c42e3..4fe0a82 100644 --- a/src/mcore_bridge/model/modules/transformer_layer.py +++ b/src/mcore_bridge/model/modules/transformer_layer.py @@ -1,7 +1,8 @@ +import enum +import inspect import megatron.core import torch from megatron.core.process_groups_config import ProcessGroupCollection -from megatron.core.transformer.enums import CudaGraphScope, LayerType from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.mlp import MLP from megatron.core.transformer.spec_utils import ModuleSpec, build_module @@ -14,6 +15,22 @@ from mcore_bridge.utils import get_logger +try: + from megatron.core.transformer.enums import CudaGraphScope +except ImportError: + + class CudaGraphScope(enum.Enum): + """Cuda Graph Scope - defines which parts of the model to capture.""" + + full_iteration = 1 # Captures the entire training/inference iteration + attn = 2 # Captures attention layers + mlp = 3 # Captures MLP layers (dense layers only) + moe = 4 # Captures MoE layers (drop-and-pad MoE layers only) + moe_router = 5 # Captures MoE router part + moe_preprocess = 6 # Captures MoE preprocessing part (requires moe_router) + mamba = 7 # Captures Mamba layers + + mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') logger = get_logger() @@ -34,7 +51,7 @@ def __init__( pp_layer_offset: Optional[int] = None, ): self.submodules_config = submodules - super().__init__(config=config, vp_stage=vp_stage) + super(TransformerLayer, self).__init__(config=config, vp_stage=vp_stage) if pg_collection is None: pg_collection = ProcessGroupCollection.use_mpu_process_groups() @@ -118,6 +135,9 @@ def __init__( from megatron.core.transformer.moe.experts import SequentialMLP, TEGroupedMLP from 
megatron.core.transformer.moe.moe_layer import MoELayer + from mcore_bridge.model.gpts.glm4 import Glm4MLP + from mcore_bridge.model.mm_gpts.gemma4 import Gemma4MLP + # MLP expects tp_group but MoELayer expects pg_collection to be passed in. # We can change MLP to accept pg_collection but it makes the logic implicit # The conditional below is to make the logic explicit @@ -126,16 +146,18 @@ def __init__( if submodules.mlp.module in (MoELayer, TEGroupedMLP, SequentialMLP): additional_mlp_kwargs['pg_collection'] = pg_collection # Pass is_mtp_layer flag to MoELayer to distinguish MTP MoE layers. - if submodules.mlp.module == MoELayer: + if submodules.mlp.module == MoELayer and 'is_mtp_layer' in inspect.signature(MoELayer).parameters: additional_mlp_kwargs['is_mtp_layer'] = self.is_mtp_layer - elif submodules.mlp.module == MLP: + elif submodules.mlp.module in (MLP, Glm4MLP): assert hasattr(pg_collection, 'tp'), 'TP process group is required for MLP in TransformerLayer' additional_mlp_kwargs['tp_group'] = pg_collection.tp + elif submodules.mlp.module == Gemma4MLP: + additional_mlp_kwargs['layer_number'] = layer_number elif TEFusedMLP is not None and submodules.mlp.module == TEFusedMLP: assert hasattr(pg_collection, 'tp'), 'TP process group is required for TEFusedMLP in TransformerLayer' additional_mlp_kwargs['tp_group'] = pg_collection.tp else: - logger.warning_once(f"Unknown MLP type: {type(submodules.mlp)}. Using default kwargs.") + logger.warning_once(f'Unknown MLP type: {type(submodules.mlp)}. Using default kwargs.') self.mlp = build_module(submodules.mlp, config=self.config, **additional_mlp_kwargs) if hasattr(self.mlp, 'set_layer_number'): self.mlp.set_layer_number(self.layer_number) @@ -198,12 +220,13 @@ def can_recompute_pre_mlp_layernorm_for_cudagraph(): if 'mlp' in self.config.recompute_modules: if not self.is_moe_layer: self.recompute_mlp = True - self.offload_attn_norm = ( - self.config.fine_grained_activation_offloading and 'attn_norm' in self.config.offload_modules - and not isinstance(self.input_layernorm, IdentityOp)) - self.offload_mlp_norm = ( - self.config.fine_grained_activation_offloading and 'mlp_norm' in self.config.offload_modules - and not isinstance(self.pre_mlp_layernorm, IdentityOp)) + if hasattr(self.config, 'fine_grained_activation_offloading'): + self.offload_attn_norm = ( + self.config.fine_grained_activation_offloading and 'attn_norm' in self.config.offload_modules + and not isinstance(self.input_layernorm, IdentityOp)) + self.offload_mlp_norm = ( + self.config.fine_grained_activation_offloading and 'mlp_norm' in self.config.offload_modules + and not isinstance(self.pre_mlp_layernorm, IdentityOp)) # @jcasper how should we handle nvfuser? # Set bias+dropout+add fusion grad_enable execution handler. 
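# All loaders in this series follow the spec-override pattern that the register.py hunk
# below generalizes (sketch; `config` and `vp_stage` stand for the values a ModelLoader
# already holds):
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec

from mcore_bridge.model.mm_gpts.gemma4 import Gemma4MLP, Gemma4SelfAttention
from mcore_bridge.model.modules import CustomTransformerLayer

layer_specs = get_gpt_decoder_block_spec(
    config, use_transformer_engine=True, normalization=config.normalization, vp_stage=vp_stage)
for layer_spec in layer_specs.layer_specs:
    layer_spec.module = CustomTransformerLayer  # shared custom forward (mlp_padding_free)
    # per-model overrides, e.g. for gemma4:
    layer_spec.submodules.self_attention.module = Gemma4SelfAttention
    layer_spec.submodules.mlp.module = Gemma4MLP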
diff --git a/src/mcore_bridge/model/register.py b/src/mcore_bridge/model/register.py index 9be5b6f..e8eef7d 100644 --- a/src/mcore_bridge/model/register.py +++ b/src/mcore_bridge/model/register.py @@ -90,41 +90,20 @@ def _replace_spec_dsa(self, layer_spec): layer_spec.submodules.self_attention = dsa_spec def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): - if self.config.num_moe_experts: - transformer_layer_spec = get_gpt_decoder_block_spec( - self.config, - use_transformer_engine=True, - normalization=self.config.normalization, - qk_l2_norm=self.config.qk_l2_norm, - vp_stage=vp_stage) - if self.config.experimental_attention_variant == 'dsa': - for layer_spec in transformer_layer_spec.layer_specs: - self._replace_spec_dsa(layer_spec) - else: - transformer_layer_spec = self._get_transformer_layer_spec() - return transformer_layer_spec - - def _get_transformer_layer_spec(self): - config = self.config - transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( - config.num_moe_experts, - config.moe_grouped_gemm, - config.qk_layernorm, - config.multi_latent_attention, - qk_l2_norm=config.qk_l2_norm, - ) + transformer_layer_spec = get_gpt_decoder_block_spec( + self.config, + use_transformer_engine=True, + normalization=self.config.normalization, + qk_l2_norm=self.config.qk_l2_norm, + vp_stage=vp_stage) + if self.config.experimental_attention_variant == 'dsa': + for layer_spec in transformer_layer_spec.layer_specs: + self._replace_spec_dsa(layer_spec) return transformer_layer_spec def get_mtp_block_spec(self, transformer_layer_spec, vp_stage: Optional[int] = None): - if hasattr(transformer_layer_spec, 'layer_specs') and len(transformer_layer_spec.layer_specs) == 0: - # Get the decoder layer spec explicitly if no decoder layer in the last stage, - # Only happens with block spec (TransformerBlockSubmodules) when using MoE. 
- # TODO: remove - transformer_layer_spec_for_mtp = self._get_transformer_layer_spec() - else: - transformer_layer_spec_for_mtp = transformer_layer_spec mtp_block_spec = get_gpt_mtp_block_spec( - self.config, transformer_layer_spec_for_mtp, use_transformer_engine=True, vp_stage=vp_stage) + self.config, transformer_layer_spec, use_transformer_engine=True, vp_stage=vp_stage) if mtp_block_spec is not None: for layer_spec in mtp_block_spec.layer_specs: layer_spec.module = MultiTokenPredictionLayer @@ -139,8 +118,8 @@ def _set_shared_expert_gate(self, transformer_layer_spec): layer_spec.submodules.mlp.submodules.shared_experts.params = {'gate': True} def _set_custom_layer(self, transformer_layer_spec): - pass - # CustomTransformerLayer + for layer_spec in transformer_layer_spec.layer_specs: + layer_spec.module = CustomTransformerLayer def build_model( self, @@ -150,6 +129,7 @@ def build_model( ) -> Union['GPTModel', 'MultimodalGPTModel']: transformer_layer_spec = self.get_transformer_layer_spec(vp_stage=vp_stage) self._set_shared_expert_gate(transformer_layer_spec) + self._set_custom_layer(transformer_layer_spec) mtp_block_spec = None if self.config.mtp_num_layers is not None: mtp_block_spec = self.get_mtp_block_spec(transformer_layer_spec, vp_stage=vp_stage) From 68e33a7dc04b3b8a746993802f98b5437ba705ae Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 5 May 2026 13:31:46 +0800 Subject: [PATCH 11/26] update --- src/mcore_bridge/model/modules/transformer_layer.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/mcore_bridge/model/modules/transformer_layer.py b/src/mcore_bridge/model/modules/transformer_layer.py index 4fe0a82..83dcf5e 100644 --- a/src/mcore_bridge/model/modules/transformer_layer.py +++ b/src/mcore_bridge/model/modules/transformer_layer.py @@ -1,6 +1,5 @@ import enum import inspect -import megatron.core import torch from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.identity_op import IdentityOp @@ -10,7 +9,6 @@ from megatron.core.transformer.transformer_layer import (TransformerLayer, TransformerLayerSubmodules, get_transformer_layer_offset) from megatron.core.utils import get_pg_rank -from packaging import version from typing import Optional from mcore_bridge.utils import get_logger @@ -31,8 +29,6 @@ class CudaGraphScope(enum.Enum): mamba = 7 # Captures Mamba layers -mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') - logger = get_logger() @@ -243,8 +239,6 @@ def forward(self, *args, **kwargs): This method calls the core computation of a transformer layer, including self-attention, cross-attention (if applicable), and feed-forward operations. 
""" - if not mcore_013: - return super().forward(self, *args, **kwargs) hidden_states, context = self._forward_attention(*args, **kwargs) mlp_padding_free = self.config.mlp_padding_free and 'attention_mask' in kwargs mask = None From 44ddaec8fec777dbda3b627c5718694d7d1bb9a8 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 5 May 2026 20:31:15 +0800 Subject: [PATCH 12/26] update --- .../model/modules/transformer_layer.py | 22 ++++++++++++++++++- src/mcore_bridge/model/register.py | 4 +--- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/src/mcore_bridge/model/modules/transformer_layer.py b/src/mcore_bridge/model/modules/transformer_layer.py index 83dcf5e..11c940b 100644 --- a/src/mcore_bridge/model/modules/transformer_layer.py +++ b/src/mcore_bridge/model/modules/transformer_layer.py @@ -2,6 +2,8 @@ import inspect import torch from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.tensor_parallel.mappings import (gather_from_sequence_parallel_region, + scatter_to_sequence_parallel_region) from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.mlp import MLP from megatron.core.transformer.spec_utils import ModuleSpec, build_module @@ -242,12 +244,30 @@ def forward(self, *args, **kwargs): hidden_states, context = self._forward_attention(*args, **kwargs) mlp_padding_free = self.config.mlp_padding_free and 'attention_mask' in kwargs mask = None + enable_sp = self.config.sequence_parallel and self.config.tensor_model_parallel_size > 1 + pad_size = 0 if mlp_padding_free and hidden_states.shape[1] > 1: + if enable_sp: + hidden_states = gather_from_sequence_parallel_region(hidden_states, tensor_parallel_output_grad=False) mask = ((~kwargs['attention_mask']).sum(dim=(1, 2)) > 0).t() hidden_states = hidden_states[mask][:, None] + if enable_sp: + tp_size = self.config.tensor_model_parallel_size + num_tokens = hidden_states.shape[0] + remainder = num_tokens % tp_size + if remainder != 0: + pad_size = tp_size - remainder + hidden_states = torch.nn.functional.pad(hidden_states, (0, 0, 0, 0, 0, pad_size)) + hidden_states = scatter_to_sequence_parallel_region(hidden_states) output = self._forward_mlp(hidden_states, kwargs.get('inference_context', None)) if mask is not None: - new_output = hidden_states.new_zeros((*mask.shape, output.shape[-1])) + if enable_sp: + output = gather_from_sequence_parallel_region(output, tensor_parallel_output_grad=False) + if pad_size > 0: + output = output[:-pad_size] + new_output = output.new_zeros((*mask.shape, output.shape[-1])) new_output[mask] = output.squeeze(1) output = new_output + if enable_sp: + output = scatter_to_sequence_parallel_region(output) return output, context diff --git a/src/mcore_bridge/model/register.py b/src/mcore_bridge/model/register.py index e8eef7d..15b37fe 100644 --- a/src/mcore_bridge/model/register.py +++ b/src/mcore_bridge/model/register.py @@ -4,9 +4,7 @@ from megatron.core import mpu from megatron.core.enums import ModelType from megatron.core.extensions.transformer_engine import TEGroupedLinear, TELayerNormColumnParallelLinear, TELinear -from megatron.core.models.gpt.gpt_layer_specs import (get_gpt_decoder_block_spec, - get_gpt_layer_with_transformer_engine_spec, - get_gpt_mtp_block_spec) +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec, get_gpt_mtp_block_spec from packaging import version from torch import nn from typing import TYPE_CHECKING, List, Optional, Type, Union From 0de0ebba80cdb40092a6436403a0e103620b3826 Mon 
Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 9 May 2026 10:37:52 +0800 Subject: [PATCH 13/26] update --- src/mcore_bridge/config/model_config.py | 5 ++--- src/mcore_bridge/config/parser.py | 2 ++ src/mcore_bridge/model/gpt_model.py | 2 +- src/mcore_bridge/model/mm_gpts/gemma4.py | 15 +++++++++++++-- tests/test_mllm.py | 7 ++++++- 5 files changed, 24 insertions(+), 7 deletions(-) diff --git a/src/mcore_bridge/config/model_config.py b/src/mcore_bridge/config/model_config.py index 5eec352..943c100 100644 --- a/src/mcore_bridge/config/model_config.py +++ b/src/mcore_bridge/config/model_config.py @@ -3,13 +3,13 @@ import os import re import torch.nn.functional as F -from dataclasses import dataclass, field +from dataclasses import dataclass from megatron.core import mpu from megatron.core.transformer import TransformerConfig from transformers import PretrainedConfig from transformers.utils import is_torch_npu_available from transformers.utils.versions import require_version -from typing import Any, Dict, List, Literal, Optional, Union +from typing import List, Literal, Optional, Union from mcore_bridge.utils import get_logger, json_parse_to_dict @@ -229,7 +229,6 @@ class ModelConfig(TransformerConfig): task_type: Literal['causal_lm', 'seq_cls', 'embedding', 'generative_reranker'] = 'causal_lm' num_labels: Optional[int] = None mlp_padding_free: bool = False - model_kwargs: Dict[str, Any] = field(default_factory=dict) _mindspeed_defaults_cache = None diff --git a/src/mcore_bridge/config/parser.py b/src/mcore_bridge/config/parser.py index 877b63c..68ef8e3 100644 --- a/src/mcore_bridge/config/parser.py +++ b/src/mcore_bridge/config/parser.py @@ -149,6 +149,8 @@ def hf_to_mcore_config(hf_config: PretrainedConfig) -> Dict[str, Any]: n_shared_experts = res.pop('n_shared_experts') elif llm_model_type in {'ernie4_5', 'ernie4_5_moe', 'glm4'}: res['rotary_interleaved'] = True + elif hf_model_type == 'gemma4': + config.qk_layernorm = True elif llm_model_type == 'gpt_oss': res['add_bias_linear'] = True res['bias_dropout_fusion'] = False diff --git a/src/mcore_bridge/model/gpt_model.py b/src/mcore_bridge/model/gpt_model.py index 86c9587..5679122 100644 --- a/src/mcore_bridge/model/gpt_model.py +++ b/src/mcore_bridge/model/gpt_model.py @@ -311,7 +311,7 @@ def forward( """ inference_context = deprecate_inference_params(inference_context, inference_params) - + # There is a difference in whether rotary_pos_emb can be fused between the decoder and MTP. 
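# Illustrative sketch: the kv-shared layer test that the gemma4 changes in this patch rely on
# is plain index arithmetic. Only the trailing `num_kv_shared_layers` layers reuse KV, and the
# `> 0` guard keeps the feature off when no layers are shared. The sketch uses a 0-based layer
# index and made-up sizes, not values from a real gemma4 config.
def is_kv_shared(layer_idx: int, num_hidden_layers: int, num_kv_shared_layers: int) -> bool:
    first_kv_shared_layer_idx = num_hidden_layers - num_kv_shared_layers
    return layer_idx >= first_kv_shared_layer_idx > 0

assert is_kv_shared(layer_idx=40, num_hidden_layers=48, num_kv_shared_layers=16)
assert not is_kv_shared(layer_idx=10, num_hidden_layers=48, num_kv_shared_layers=16)
assert not is_kv_shared(layer_idx=47, num_hidden_layers=48, num_kv_shared_layers=0)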
decoder_input, rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset = ( self._preprocess( input_ids=input_ids, diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index 8bfc0c2..a45fb5a 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -15,6 +15,7 @@ from ..register import ModelLoader, ModelMeta, register_model from ..rope import get_rope_inv_freq from .utils import HuggingFaceVit +from ..module import CustomTransformerLayer class Gemma4Vit(HuggingFaceVit): @@ -76,9 +77,12 @@ def __init__( text_config = config.hf_config.text_config self.enable_moe_block = text_config.enable_moe_block first_kv_shared_layer_idx = text_config.num_hidden_layers - text_config.num_kv_shared_layers - is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0 - use_double_wide_mlp = config.use_double_wide_mlp and is_kv_shared_layer + is_kv_shared_layer = layer_number > first_kv_shared_layer_idx > 0 + use_double_wide_mlp = text_config.use_double_wide_mlp and is_kv_shared_layer + ffn_hidden_size = config.ffn_hidden_size + config.ffn_hidden_size = config.ffn_hidden_size * (2 if use_double_wide_mlp else 1) super().__init__(config, submodules, *args, **kwargs) + config.ffn_hidden_size = ffn_hidden_size class Gemma4Bridge(MultimodalGPTBridge): @@ -110,6 +114,9 @@ def _set_inv_freq(self): self.config.rope_scaling = rope_scaling +class Gemma4TransformerLayer(CustomTransformerLayer): + pass + class Gemma4GPTModel(MultimodalGPTModel): language_model_cls = Gemma4TextGPTModel @@ -127,6 +134,10 @@ def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): return layer_specs + def _set_custom_layer(self, transformer_layer_spec): + for layer_spec in transformer_layer_spec.layer_specs: + layer_spec.module = Gemma4TransformerLayer + register_model( ModelMeta( ModelType.gemma4, diff --git a/tests/test_mllm.py b/tests/test_mllm.py index f4832f7..bacf76c 100644 --- a/tests/test_mllm.py +++ b/tests/test_mllm.py @@ -112,6 +112,10 @@ def test_llava_onevision1_5(): _test_model('lmms-lab/LLaVA-OneVision-1.5-4B-Instruct') +def test_gemma4(): + _test_model('google/gemma-4-E2B-it') + + if __name__ == '__main__': # test_qwen2_5_vl() # test_qwen2_vl() @@ -131,4 +135,5 @@ def test_llava_onevision1_5(): # test_qwen3_omni() # test_llama4() # test_qwen3_5() - test_llava_onevision1_5() + # test_llava_onevision1_5() + test_gemma4() From 2a81bf056b08fb3784c731f56d63d05c90bbee86 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 9 May 2026 10:47:07 +0800 Subject: [PATCH 14/26] update --- src/mcore_bridge/model/mm_gpts/gemma4.py | 51 +++++++++++++++++++++--- 1 file changed, 46 insertions(+), 5 deletions(-) diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index a45fb5a..e3ef62d 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -53,14 +53,55 @@ def __init__( **kwargs, ): text_config = config.hf_config.text_config - self.is_sliding = text_config.layer_types[layer_number - 1] == 'sliding_attention' + layer_idx = layer_number - 1 + + # Layer type / sliding attention + self.layer_type = text_config.layer_types[layer_idx] + self.is_sliding = self.layer_type == 'sliding_attention' self.sliding_window = text_config.sliding_window if self.is_sliding else None - kv_channels = config.kv_channels - config.kv_channels = ( + + # Head dim: global layers may use a different head dim than sliding ones + 
self.head_dim = ( text_config.global_head_dim if not self.is_sliding and text_config.global_head_dim else text_config.head_dim) - super().__init__(config, submodules, layer_number, *args, **kwargs) - config.kv_channels = kv_channels + + # Alternative attention (k == v) for global layers when `attention_k_eq_v` is set + self.use_alternative_attention = ( + getattr(text_config, 'attention_k_eq_v', False) and not self.is_sliding) + num_key_value_heads = ( + text_config.num_global_key_value_heads + if self.use_alternative_attention else text_config.num_key_value_heads) + self.num_key_value_groups = text_config.num_attention_heads // num_key_value_heads + + self.is_causal = getattr(text_config, 'use_bidirectional_attention', None) != 'all' + + # Shared KV across the trailing layers + num_kv_shared_layers = getattr(text_config, 'num_kv_shared_layers', 0) + first_kv_shared_layer_idx = text_config.num_hidden_layers - num_kv_shared_layers + self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0 + prev_layers = text_config.layer_types[:first_kv_shared_layer_idx] + if self.is_kv_shared_layer: + # For shared layers, reuse KV from the last non-shared layer of the same type + self.kv_shared_layer_index = ( + len(prev_layers) - 1 - prev_layers[::-1].index(self.layer_type)) + self.store_full_length_kv = False + else: + self.kv_shared_layer_index = None + # Non-shared layers that are the last of their type in `prev_layers` must keep full KV + self.store_full_length_kv = ( + self.layer_type in prev_layers + and layer_idx == len(prev_layers) - 1 - prev_layers[::-1].index(self.layer_type)) + + # Patch config so the underlying linear_qkv is built with the correct shapes + orig_kv_channels = config.kv_channels + orig_num_query_groups = config.num_query_groups + config.kv_channels = self.head_dim + config.num_query_groups = num_key_value_heads + try: + super().__init__(config, submodules, layer_number, *args, **kwargs) + finally: + config.kv_channels = orig_kv_channels + config.num_query_groups = orig_num_query_groups class Gemma4MLP(MLP): From 7e05d3d70ebdcdcea606b26120a6e0af4cdca8d4 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 9 May 2026 10:51:32 +0800 Subject: [PATCH 15/26] update --- src/mcore_bridge/model/gpts/minimax_m2.py | 8 +++++--- src/mcore_bridge/model/mm_gpts/gemma4.py | 6 ++++-- src/mcore_bridge/model/modules/gated_delta_net.py | 7 +++++-- src/mcore_bridge/model/modules/mtp_layer.py | 7 +++++-- 4 files changed, 19 insertions(+), 9 deletions(-) diff --git a/src/mcore_bridge/model/gpts/minimax_m2.py b/src/mcore_bridge/model/gpts/minimax_m2.py index c03f803..7830215 100644 --- a/src/mcore_bridge/model/gpts/minimax_m2.py +++ b/src/mcore_bridge/model/gpts/minimax_m2.py @@ -27,9 +27,11 @@ def __init__( k_layernorm = submodules.k_layernorm submodules.q_layernorm = IdentityOp submodules.k_layernorm = IdentityOp - super().__init__(config, submodules, *args, **kwargs) - submodules.q_layernorm = q_layernorm - submodules.k_layernorm = k_layernorm + try: + super().__init__(config, submodules, *args, **kwargs) + finally: + submodules.q_layernorm = q_layernorm + submodules.k_layernorm = k_layernorm self.q_norm = build_module( submodules.q_layernorm, hidden_size=self.hidden_size_per_attention_head * config.num_attention_heads, diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index e3ef62d..6c267ce 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -122,8 +122,10 @@ def __init__( 
use_double_wide_mlp = text_config.use_double_wide_mlp and is_kv_shared_layer ffn_hidden_size = config.ffn_hidden_size config.ffn_hidden_size = config.ffn_hidden_size * (2 if use_double_wide_mlp else 1) - super().__init__(config, submodules, *args, **kwargs) - config.ffn_hidden_size = ffn_hidden_size + try: + super().__init__(config, submodules, *args, **kwargs) + finally: + config.ffn_hidden_size = ffn_hidden_size class Gemma4Bridge(MultimodalGPTBridge): diff --git a/src/mcore_bridge/model/modules/gated_delta_net.py b/src/mcore_bridge/model/modules/gated_delta_net.py index ef150b7..5fcb167 100644 --- a/src/mcore_bridge/model/modules/gated_delta_net.py +++ b/src/mcore_bridge/model/modules/gated_delta_net.py @@ -96,10 +96,13 @@ def __init__(self, config: ModelConfig, submodules: 'GatedDeltaNetSubmodules', * submodules.in_proj = IdentityOp if 'cp_comm_type' not in inspect.signature(_GatedDeltaNet).parameters: kwargs.pop('cp_comm_type', None) - super().__init__(config, submodules, *args, **kwargs) + try: + super().__init__(config, submodules, *args, **kwargs) + finally: + if config.linear_decoupled_in_proj: + submodules.in_proj = in_proj if not config.linear_decoupled_in_proj: return - submodules.in_proj = in_proj self.in_proj_qkvz_dim = self.qk_dim * 2 + self.v_dim * 2 self.in_proj_ba_dim = self.num_value_heads * 2 del self.in_proj diff --git a/src/mcore_bridge/model/modules/mtp_layer.py b/src/mcore_bridge/model/modules/mtp_layer.py index 537abf3..8be6aeb 100644 --- a/src/mcore_bridge/model/modules/mtp_layer.py +++ b/src/mcore_bridge/model/modules/mtp_layer.py @@ -29,11 +29,14 @@ def __init__(self, config: ModelConfig, submodules, *args, **kwargs): if config.fp8_param: eh_proj = submodules.eh_proj submodules.eh_proj = IdentityOp - super().__init__(config, submodules, *args, **kwargs) + try: + super().__init__(config, submodules, *args, **kwargs) + finally: + if config.fp8_param: + submodules.eh_proj = eh_proj self.tp_group = getattr(self, 'tp_group', None) if not config.fp8_param: return - submodules.eh_proj = eh_proj fp8_context = transformer_engine.pytorch.fp8_model_init(enabled=False) with fp8_context: self.eh_proj = build_module( From 8da05dfc646c04cb857471b53b7e42a17470020e Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 9 May 2026 10:58:44 +0800 Subject: [PATCH 16/26] fix --- src/mcore_bridge/model/mm_gpts/gemma4.py | 30 +++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index 6c267ce..e14716d 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -1,7 +1,9 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
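# Illustrative sketch of the pattern the try/finally changes above converge on: a shared
# config or submodules object is patched only for the duration of the parent constructor,
# and restored even if construction raises, so layers built afterwards never see the
# temporarily patched value. Generic form with made-up names:
def build_with_patched_attr(owner, attr, temp_value, build_fn):
    original = getattr(owner, attr)
    setattr(owner, attr, temp_value)
    try:
        return build_fn()
    finally:
        setattr(owner, attr, original)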
import copy +import torch from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules +from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.mlp import MLP from transformers import AutoModel, PretrainedConfig from typing import Optional @@ -18,6 +20,20 @@ from ..module import CustomTransformerLayer +class Gemma4VNorm(torch.nn.Module): + """RMSNorm without learnable scale, mirroring HF `Gemma4RMSNorm(with_scale=False)`.""" + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + orig_dtype = x.dtype + x = x.to(torch.float32) + variance = x.pow(2).mean(-1, keepdim=True) + return (x * torch.rsqrt(variance + self.eps)).to(orig_dtype) + + class Gemma4Vit(HuggingFaceVit): module_mapping = { 'model.vision_tower': 'vision_tower', @@ -92,16 +108,28 @@ def __init__( self.layer_type in prev_layers and layer_idx == len(prev_layers) - 1 - prev_layers[::-1].index(self.layer_type)) - # Patch config so the underlying linear_qkv is built with the correct shapes + # Patch config so the underlying linear_qkv/q_layernorm/k_layernorm are built correctly. + # HF keeps `q_norm` on every layer, but only builds `k_norm`/`v_norm` on non-kv-shared + # layers, so replace `k_layernorm` with `IdentityOp` when this layer shares KV. orig_kv_channels = config.kv_channels orig_num_query_groups = config.num_query_groups + orig_k_layernorm = submodules.k_layernorm config.kv_channels = self.head_dim config.num_query_groups = num_key_value_heads + if self.is_kv_shared_layer: + submodules.k_layernorm = IdentityOp try: super().__init__(config, submodules, layer_number, *args, **kwargs) finally: config.kv_channels = orig_kv_channels config.num_query_groups = orig_num_query_groups + submodules.k_layernorm = orig_k_layernorm + + # HF builds a `v_norm` (RMSNorm without learnable scale) for non-kv-shared layers. + # mcore's SelfAttention has no v_layernorm by default, so attach one explicitly here. 
+ self.v_norm = ( + Gemma4VNorm(self.head_dim, eps=self.config.layernorm_epsilon) + if not self.is_kv_shared_layer else None) class Gemma4MLP(MLP): From d1eff8a9f4e25bec42e770813fe47eb3a4b787a5 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 9 May 2026 11:00:09 +0800 Subject: [PATCH 17/26] fix --- src/mcore_bridge/config/parser.py | 2 -- src/mcore_bridge/model/mm_gpts/gemma4.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/mcore_bridge/config/parser.py b/src/mcore_bridge/config/parser.py index 68ef8e3..877b63c 100644 --- a/src/mcore_bridge/config/parser.py +++ b/src/mcore_bridge/config/parser.py @@ -149,8 +149,6 @@ def hf_to_mcore_config(hf_config: PretrainedConfig) -> Dict[str, Any]: n_shared_experts = res.pop('n_shared_experts') elif llm_model_type in {'ernie4_5', 'ernie4_5_moe', 'glm4'}: res['rotary_interleaved'] = True - elif hf_model_type == 'gemma4': - config.qk_layernorm = True elif llm_model_type == 'gpt_oss': res['add_bias_linear'] = True res['bias_dropout_fusion'] = False diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index e3ef62d..e7dcace 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -15,7 +15,7 @@ from ..register import ModelLoader, ModelMeta, register_model from ..rope import get_rope_inv_freq from .utils import HuggingFaceVit -from ..module import CustomTransformerLayer +from ..modules import CustomTransformerLayer class Gemma4Vit(HuggingFaceVit): From 0c22e68aec252d8d4a1cad9ef79fbc3d27e3de79 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 9 May 2026 11:02:07 +0800 Subject: [PATCH 18/26] fix --- src/mcore_bridge/model/mm_gpts/gemma4.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index e14716d..be0a17f 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -89,8 +89,6 @@ def __init__( if self.use_alternative_attention else text_config.num_key_value_heads) self.num_key_value_groups = text_config.num_attention_heads // num_key_value_heads - self.is_causal = getattr(text_config, 'use_bidirectional_attention', None) != 'all' - # Shared KV across the trailing layers num_kv_shared_layers = getattr(text_config, 'num_kv_shared_layers', 0) first_kv_shared_layer_idx = text_config.num_hidden_layers - num_kv_shared_layers From fa5360be7b1a8624b66d0c6d232acab958577408 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 9 May 2026 11:13:45 +0800 Subject: [PATCH 19/26] fix --- src/mcore_bridge/config/parser.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/mcore_bridge/config/parser.py b/src/mcore_bridge/config/parser.py index 877b63c..21d34c5 100644 --- a/src/mcore_bridge/config/parser.py +++ b/src/mcore_bridge/config/parser.py @@ -149,6 +149,8 @@ def hf_to_mcore_config(hf_config: PretrainedConfig) -> Dict[str, Any]: n_shared_experts = res.pop('n_shared_experts') elif llm_model_type in {'ernie4_5', 'ernie4_5_moe', 'glm4'}: res['rotary_interleaved'] = True + elif hf_model_type in {'gemma4'}: + res['qk_layernorm'] = True elif llm_model_type == 'gpt_oss': res['add_bias_linear'] = True res['bias_dropout_fusion'] = False From d25db28302db3cdcb2ad8aa6a75dd9cbfa03cece Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 9 May 2026 11:14:39 +0800 Subject: [PATCH 20/26] update --- src/mcore_bridge/bridge/gpt_bridge.py | 9 +++++++-- src/mcore_bridge/model/mm_gpts/gemma4.py | 21 ++++++++++++--------- 2 
files changed, 19 insertions(+), 11 deletions(-) diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py index c380328..62f6cb7 100644 --- a/src/mcore_bridge/bridge/gpt_bridge.py +++ b/src/mcore_bridge/bridge/gpt_bridge.py @@ -439,7 +439,8 @@ def _set_state_dict(self, to_mcore: bool, *, offset: float = 0, - is_expert: bool = False): + is_expert: bool = False, + _check_mg_param: bool = True): if '.' in mg_key: module_key, param_key = mg_key.rsplit('.', 1) else: @@ -487,7 +488,11 @@ def _set_state_dict(self, else: mg_param = deep_getattr(sub_module, param_key) if to_mcore: - assert mg_param is not None, f'mg_module: {mg_module}, mg_key: {mg_key}' + if mg_param is None: + if _check_mg_param: + raise ValueError(f'mg_module: {mg_module}, mg_key: {mg_key}') + else: + return hf_weight = hf_state_dict[hf_key].load() if module_key in { 'embedding.word_embeddings', 'output_layer' diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index a41fdf0..aa4b7fb 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -14,10 +14,10 @@ from ..constant import ModelType from ..gpt_model import GPTModel from ..mm_gpt_model import MultimodalGPTModel +from ..modules import CustomTransformerLayer from ..register import ModelLoader, ModelMeta, register_model from ..rope import get_rope_inv_freq from .utils import HuggingFaceVit -from ..modules import CustomTransformerLayer class Gemma4VNorm(torch.nn.Module): @@ -82,8 +82,7 @@ def __init__( if not self.is_sliding and text_config.global_head_dim else text_config.head_dim) # Alternative attention (k == v) for global layers when `attention_k_eq_v` is set - self.use_alternative_attention = ( - getattr(text_config, 'attention_k_eq_v', False) and not self.is_sliding) + self.use_alternative_attention = (getattr(text_config, 'attention_k_eq_v', False) and not self.is_sliding) num_key_value_heads = ( text_config.num_global_key_value_heads if self.use_alternative_attention else text_config.num_key_value_heads) @@ -96,8 +95,7 @@ def __init__( prev_layers = text_config.layer_types[:first_kv_shared_layer_idx] if self.is_kv_shared_layer: # For shared layers, reuse KV from the last non-shared layer of the same type - self.kv_shared_layer_index = ( - len(prev_layers) - 1 - prev_layers[::-1].index(self.layer_type)) + self.kv_shared_layer_index = (len(prev_layers) - 1 - prev_layers[::-1].index(self.layer_type)) self.store_full_length_kv = False else: self.kv_shared_layer_index = None @@ -126,8 +124,7 @@ def __init__( # HF builds a `v_norm` (RMSNorm without learnable scale) for non-kv-shared layers. # mcore's SelfAttention has no v_layernorm by default, so attach one explicitly here. 
self.v_norm = ( - Gemma4VNorm(self.head_dim, eps=self.config.layernorm_epsilon) - if not self.is_kv_shared_layer else None) + Gemma4VNorm(self.head_dim, eps=self.config.layernorm_epsilon) if not self.is_kv_shared_layer else None) class Gemma4MLP(MLP): @@ -155,7 +152,12 @@ def __init__( class Gemma4Bridge(MultimodalGPTBridge): - pass + + def _set_qk_layernorm(self, mg_attn, hf_state_dict, to_mcore): + self._set_state_dict( + mg_attn, 'q_layernorm.weight', hf_state_dict, self.hf_q_norm_key, to_mcore, _check_mg_param=False) + self._set_state_dict( + mg_attn, 'k_layernorm.weight', hf_state_dict, self.hf_k_norm_key, to_mcore, _check_mg_param=False) class Gemma4TextGPTModel(GPTModel): @@ -183,6 +185,7 @@ def _set_inv_freq(self): self.config.rope_scaling = rope_scaling + class Gemma4TransformerLayer(CustomTransformerLayer): pass @@ -202,11 +205,11 @@ def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): layer_spec.submodules.mlp.module = Gemma4MLP return layer_specs - def _set_custom_layer(self, transformer_layer_spec): for layer_spec in transformer_layer_spec.layer_specs: layer_spec.module = Gemma4TransformerLayer + register_model( ModelMeta( ModelType.gemma4, From bfbcbc4b2747e1fd14eba117460f887ac14d4fe2 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 9 May 2026 11:44:16 +0800 Subject: [PATCH 21/26] update --- .../model/modules/transformer_layer.py | 66 ++++++++++--------- 1 file changed, 34 insertions(+), 32 deletions(-) diff --git a/src/mcore_bridge/model/modules/transformer_layer.py b/src/mcore_bridge/model/modules/transformer_layer.py index be07631..74fbf59 100644 --- a/src/mcore_bridge/model/modules/transformer_layer.py +++ b/src/mcore_bridge/model/modules/transformer_layer.py @@ -1,11 +1,14 @@ import enum import inspect import torch +from megatron.core.extensions.transformer_engine import TEFusedMLP from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.tensor_parallel.mappings import (gather_from_sequence_parallel_region, scatter_to_sequence_parallel_region) from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.mlp import MLP +from megatron.core.transformer.moe.experts import SequentialMLP, TEGroupedMLP +from megatron.core.transformer.moe.moe_layer import MoELayer from megatron.core.transformer.spec_utils import ModuleSpec, build_module from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.transformer_layer import (TransformerLayer, TransformerLayerSubmodules, @@ -126,43 +129,13 @@ def __init__( hidden_size=self.config.hidden_size, eps=self.config.layernorm_epsilon, ) - # [Module 8: MLP block] - additional_mlp_kwargs = {} - # import here to avoid circular import - from megatron.core.extensions.transformer_engine import TEFusedMLP - from megatron.core.transformer.moe.experts import SequentialMLP, TEGroupedMLP - from megatron.core.transformer.moe.moe_layer import MoELayer - - from mcore_bridge.model.gpts.glm4 import Glm4MLP - from mcore_bridge.model.mm_gpts.gemma4 import Gemma4MLP - # MLP expects tp_group but MoELayer expects pg_collection to be passed in. 
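# Illustrative sketch of the kwarg dispatch that this block implements and that the patch
# moves into `_build_mlp` below: each MLP flavour receives only the extra arguments it
# accepts. Class names here are stand-ins for the ones imported above, and the is_mtp_layer
# special case for MoELayer is omitted for brevity.
def extra_mlp_kwargs(module_cls, pg_collection, layer_number):
    name = module_cls.__name__
    if name in {'MoELayer', 'TEGroupedMLP', 'SequentialMLP'}:
        return {'pg_collection': pg_collection}
    if name in {'MLP', 'Glm4MLP', 'TEFusedMLP'}:
        return {'tp_group': pg_collection.tp}
    if name == 'Gemma4MLP':
        return {'layer_number': layer_number}
    return {}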
- # We can change MLP to accept pg_collection but it makes the logic implicit - # The conditional below is to make the logic explicit - # if submodules.mlp is not a ModuleSpec,we dont have to handle passing additional kwargs - if isinstance(submodules.mlp, ModuleSpec): - if submodules.mlp.module in (MoELayer, TEGroupedMLP, SequentialMLP): - additional_mlp_kwargs['pg_collection'] = pg_collection - # Pass is_mtp_layer flag to MoELayer to distinguish MTP MoE layers. - if submodules.mlp.module == MoELayer and 'is_mtp_layer' in inspect.signature(MoELayer).parameters: - additional_mlp_kwargs['is_mtp_layer'] = self.is_mtp_layer - elif submodules.mlp.module in (MLP, Glm4MLP): - assert hasattr(pg_collection, 'tp'), 'TP process group is required for MLP in TransformerLayer' - additional_mlp_kwargs['tp_group'] = pg_collection.tp - elif submodules.mlp.module == Gemma4MLP: - additional_mlp_kwargs['layer_number'] = layer_number - elif TEFusedMLP is not None and submodules.mlp.module == TEFusedMLP: - assert hasattr(pg_collection, 'tp'), 'TP process group is required for TEFusedMLP in TransformerLayer' - additional_mlp_kwargs['tp_group'] = pg_collection.tp - else: - logger.warning_once(f'Unknown MLP type: {submodules.mlp.module}. Using default kwargs.') - self.mlp = build_module(submodules.mlp, config=self.config, **additional_mlp_kwargs) + # [Module 8: MLP block] + self.mlp = self._build_mlp(submodules.mlp) if hasattr(self.mlp, 'set_layer_number'): self.mlp.set_layer_number(self.layer_number) - # [Module 9: BiasDropoutFusion] self.mlp_bda = build_module(submodules.mlp_bda) - self.is_moe_layer = isinstance(self.mlp, MoELayer) self.recompute_input_layernorm = False @@ -234,6 +207,35 @@ def can_recompute_pre_mlp_layernorm_for_cudagraph(): # self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad self.bias_dropout_add_exec_handler = torch.enable_grad + def _build_mlp(self, mlp_spec): + pg_collection = self.pg_collection + additional_mlp_kwargs = {} + # import here to avoid circular import + from mcore_bridge.model.gpts.glm4 import Glm4MLP + from mcore_bridge.model.mm_gpts.gemma4 import Gemma4MLP + + # MLP expects tp_group but MoELayer expects pg_collection to be passed in. + # We can change MLP to accept pg_collection but it makes the logic implicit + # The conditional below is to make the logic explicit + # if smlp_spec is not a ModuleSpec,we dont have to handle passing additional kwargs + if isinstance(mlp_spec, ModuleSpec): + if mlp_spec.module in (MoELayer, TEGroupedMLP, SequentialMLP): + additional_mlp_kwargs['pg_collection'] = pg_collection + # Pass is_mtp_layer flag to MoELayer to distinguish MTP MoE layers. + if mlp_spec.module == MoELayer and 'is_mtp_layer' in inspect.signature(MoELayer).parameters: + additional_mlp_kwargs['is_mtp_layer'] = self.is_mtp_layer + elif mlp_spec.module in (MLP, Glm4MLP): + assert hasattr(pg_collection, 'tp'), 'TP process group is required for MLP in TransformerLayer' + additional_mlp_kwargs['tp_group'] = pg_collection.tp + elif mlp_spec.module == Gemma4MLP: + additional_mlp_kwargs['layer_number'] = self.layer_number + elif TEFusedMLP is not None and mlp_spec.module == TEFusedMLP: + assert hasattr(pg_collection, 'tp'), 'TP process group is required for TEFusedMLP in TransformerLayer' + additional_mlp_kwargs['tp_group'] = pg_collection.tp + else: + logger.warning_once(f'Unknown MLP type: {mlp_spec.module}. 
Using default kwargs.') + self.mlp = build_module(mlp_spec, config=self.config, **additional_mlp_kwargs) + def forward(self, *args, **kwargs): """ Perform a forward pass through the transformer layer. From 2300825168648b45820217bfa5941afb49ed4bac Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 9 May 2026 13:13:52 +0800 Subject: [PATCH 22/26] fix --- src/mcore_bridge/bridge/gpt_bridge.py | 11 +++++++---- src/mcore_bridge/model/modules/transformer_layer.py | 2 +- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py index 62f6cb7..837f5f9 100644 --- a/src/mcore_bridge/bridge/gpt_bridge.py +++ b/src/mcore_bridge/bridge/gpt_bridge.py @@ -40,9 +40,12 @@ class GPTBridge: hf_o_proj_key = 'o_proj' hf_attn_prefix = 'self_attn' hf_mlp_prefix = 'mlp' + hf_post_attention_layernorm = 'post_attention_layernorm' hf_gate_key = 'gate.weight' hf_shared_expert_key = None hf_expert_bias_key = 'gate.e_score_correction_bias' + additional_dim0_keys = {} + additional_dim1_keys = {} def __init__(self, config: ModelConfig): self.config = config @@ -124,11 +127,11 @@ def _get_tp_split_dim(self, mg_key: Optional[str]) -> Optional[int]: 'linear_kv_up_proj', # mtp 'eh_proj', - } + } & self.additional_dim0_keys if self.config.task_type in {'causal_lm', 'generative_reranker'}: dim0_keys.add('output_layer') # RowLinear - dim1_keys = {'out_proj', 'linear_proj', 'linear_fc2'} + dim1_keys = {'out_proj', 'linear_proj', 'linear_fc2'} & self.additional_dim1_keys if 'lora_A' not in mg_key and 'lora_B' not in mg_key: key, suffix = mg_key.rsplit('.', 2)[-2:] if suffix == 'layer_norm_weight': @@ -1592,13 +1595,13 @@ def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool hf_state_dict.update( self._set_moe_state( mg_mlp, hf_state_dict, f'{self.hf_mlp_prefix}.', layer_idx, to_mcore, is_mtp=is_mtp)) - self._set_state_dict(mg_layer, 'pre_mlp_layernorm.weight', hf_state_dict, 'post_attention_layernorm.weight', + self._set_state_dict(mg_layer, 'pre_mlp_layernorm.weight', hf_state_dict, f'{self.hf_post_attention_layernorm}.weight', to_mcore) else: hf_state_dict.update( self._set_mlp_state(mg_mlp, hf_state_dict, f'{self.hf_mlp_prefix}.', layer_idx, to_mcore)) self._set_state_dict(mg_layer, 'mlp.linear_fc1.layer_norm_weight', hf_state_dict, - 'post_attention_layernorm.weight', to_mcore) + f'{self.hf_post_attention_layernorm}.weight', to_mcore) return hf_state_dict def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool): diff --git a/src/mcore_bridge/model/modules/transformer_layer.py b/src/mcore_bridge/model/modules/transformer_layer.py index 74fbf59..b3bd54b 100644 --- a/src/mcore_bridge/model/modules/transformer_layer.py +++ b/src/mcore_bridge/model/modules/transformer_layer.py @@ -234,7 +234,7 @@ def _build_mlp(self, mlp_spec): additional_mlp_kwargs['tp_group'] = pg_collection.tp else: logger.warning_once(f'Unknown MLP type: {mlp_spec.module}. 
Using default kwargs.') - self.mlp = build_module(mlp_spec, config=self.config, **additional_mlp_kwargs) + return build_module(mlp_spec, config=self.config, **additional_mlp_kwargs) def forward(self, *args, **kwargs): """ From cda31a5117211c8361f904a719767c879b71c100 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 9 May 2026 14:30:39 +0800 Subject: [PATCH 23/26] update --- src/mcore_bridge/bridge/gpt_bridge.py | 11 +- src/mcore_bridge/model/mm_gpts/gemma4.py | 189 ++++++++++++++++++++++- 2 files changed, 194 insertions(+), 6 deletions(-) diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py index 837f5f9..2c410a2 100644 --- a/src/mcore_bridge/bridge/gpt_bridge.py +++ b/src/mcore_bridge/bridge/gpt_bridge.py @@ -1595,8 +1595,8 @@ def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool hf_state_dict.update( self._set_moe_state( mg_mlp, hf_state_dict, f'{self.hf_mlp_prefix}.', layer_idx, to_mcore, is_mtp=is_mtp)) - self._set_state_dict(mg_layer, 'pre_mlp_layernorm.weight', hf_state_dict, f'{self.hf_post_attention_layernorm}.weight', - to_mcore) + self._set_state_dict(mg_layer, 'pre_mlp_layernorm.weight', hf_state_dict, + f'{self.hf_post_attention_layernorm}.weight', to_mcore) else: hf_state_dict.update( self._set_mlp_state(mg_mlp, hf_state_dict, f'{self.hf_mlp_prefix}.', layer_idx, to_mcore)) @@ -1618,13 +1618,16 @@ def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: i hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) return hf_state_dict + def _set_word_embeddings(self, mg_model, hf_state_dict, to_mcore): + lm_model = getattr(mg_model, 'language_model') if self.is_multimodal else mg_model + self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, self.hf_embed_key, to_mcore) + def _convert_pre_process(self, mg_model, hf_state_dict, hf_prefix: str, to_mcore): if to_mcore: hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) else: hf_state_dict = {} - lm_model = getattr(mg_model, 'language_model') if self.is_multimodal else mg_model - self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, self.hf_embed_key, to_mcore) + self._set_word_embeddings(mg_model, hf_state_dict, to_mcore) if self.is_multimodal: for prefix, mg_prefix in self.module_mapping.items(): mg_module = deep_getattr(mg_model, f'visual.{mg_prefix}') diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index aa4b7fb..7ba0faf 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -1,10 +1,14 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
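# Illustrative sketch: with the `additional_dim0_keys` / `additional_dim1_keys` hooks added in
# the previous patch, a model-specific bridge only declares which of its extra linears are
# column-parallel (split on dim 0) or row-parallel (split on dim 1), and these extend the
# built-in key sets as a union. Key names below are the gemma4 ones; the base set is trimmed.
base_dim0_keys = {'linear_qkv', 'linear_fc1', 'eh_proj'}
additional_dim0_keys = {'per_layer_input_gate', 'per_layer_model_projection'}
dim0_keys = base_dim0_keys | additional_dim0_keys
assert {'linear_qkv', 'per_layer_model_projection'} <= dim0_keys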
import copy +import math import torch +from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TENorm, TERowParallelLinear from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec +from megatron.core.tensor_parallel import VocabParallelEmbedding from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.mlp import MLP +from megatron.core.transformer.spec_utils import build_module from transformers import AutoModel, PretrainedConfig from typing import Optional @@ -121,6 +125,25 @@ def __init__( config.num_query_groups = orig_num_query_groups submodules.k_layernorm = orig_k_layernorm + # HF kv-shared layers only keep `q_proj` (K/V are reused from an earlier layer), so the + # default mcore `linear_qkv` shape `[Q + 2*KV, hidden]` over-allocates. Rebuild it with + # out_dim = query_projection_size so shapes match HF `q_proj` 1:1 for weight bridging. + # Mirrors attention.py L1275-L1289, minus the `+ 2 * kv_projection_size` term. + if self.is_kv_shared_layer: + self.linear_qkv_out_dim = self.query_projection_size + self.linear_qkv = submodules.linear_qkv( + self.config.hidden_size, + self.linear_qkv_out_dim, + config=self.config, + init_method=self.config.init_method, + gather_output=False, + bias=self.config.add_bias_linear or self.config.add_qkv_bias, + skip_bias_add=False, + is_expert=False, + tp_comm_buffer_name='qkv', + tp_group=self.pg_collection.tp, + ) + # HF builds a `v_norm` (RMSNorm without learnable scale) for non-kv-shared layers. # mcore's SelfAttention has no v_layernorm by default, so attach one explicitly here. self.v_norm = ( @@ -152,6 +175,9 @@ def __init__( class Gemma4Bridge(MultimodalGPTBridge): + hf_post_attention_layernorm = 'pre_feedforward_layernorm' + additional_dim0_keys = {'per_layer_input_gate', 'per_layer_model_projection'} + additional_dim1_keys = {'per_layer_projection'} def _set_qk_layernorm(self, mg_attn, hf_state_dict, to_mcore): self._set_state_dict( @@ -159,12 +185,110 @@ def _set_qk_layernorm(self, mg_attn, hf_state_dict, to_mcore): self._set_state_dict( mg_attn, 'k_layernorm.weight', hf_state_dict, self.hf_k_norm_key, to_mcore, _check_mg_param=False) + def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool): + is_kv_shared_layer = False if mg_attn is None else mg_attn.is_kv_shared_layer + is_kv_shared_layer = torch.tensor([is_kv_shared_layer], dtype=torch.bool, device='cuda') + if self.pp_size > 1: + dist.all_reduce(is_lora, group=self.pp_group, op=dist.ReduceOp.MAX) + is_kv_shared_layer = is_kv_shared_layer.item() + if is_kv_shared_layer: + self._set_state_dict(mg_attn, 'linear_qkv.weight', hf_state_dict, 'q_proj.weight', to_mcore) + return hf_state_dict + else: + return super()._set_qkv(mg_attn, hf_state_dict, to_mcore) + + def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool): + hf_prefix = f'{hf_prefix}{layer_idx}.' 
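# Why `_set_qkv` above all-reduces the flag with MAX: under pipeline parallelism only the rank
# that owns the layer sees a real `mg_attn`; the other ranks see None and would default to
# False. Taking the maximum of the per-rank booleans makes every rank follow the owner, so all
# ranks take the same branch while converting weights. Single-process illustration of the
# reduction (the real code uses torch.distributed over the PP group):
per_rank_flags = [False, False, True, False]   # one entry per PP rank; the owner says True
agreed = max(per_rank_flags)                   # what ReduceOp.MAX computes across ranks
assert agreed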
+ if to_mcore: + hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) + else: + hf_state_dict = {} + hf_state_dict.update(self._set_layer_attn(mg_layer, hf_state_dict, layer_idx, to_mcore)) + hf_state_dict.update(self._set_layer_mlp(mg_layer, hf_state_dict, layer_idx, to_mcore)) + for key in [ + 'post_attention_layernorm', 'post_feedforward_layernorm', 'per_layer_input_gate', + 'per_layer_projection', 'post_per_layer_input_norm' + ]: + self._set_state_dict( + mg_layer, + f'{key}.weight', + hf_state_dict if to_mcore else new_hf_state_dict, + f'{key}.weight', + to_mcore, + _check_mg_param=False) + if to_mcore: + hf_state_dict = {} + else: + hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) + return hf_state_dict + + def _set_word_embeddings(self, mg_model, hf_state_dict, to_mcore): + lm_model = getattr(mg_model, 'language_model') if self.is_multimodal else mg_model + self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, self.hf_embed_key, to_mcore) + for key in ['embed_tokens_per_layer', 'per_layer_model_projection', 'per_layer_projection_norm']: + self._set_state_dict(lm_model, f'{key}.weight', hf_state_dict, f'model.language_model.{key}.weight', + to_mcore) + class Gemma4TextGPTModel(GPTModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - print() + text_config = self.config.hf_config.text_config + # HF: `self.unique_layer_types = set(self.config.layer_types)` — needed by the rotary + # embedding selection logic (sliding vs global) when that path is wired up. + self.unique_layer_types = set(text_config.layer_types) + + # HF: Per-Layer Embeddings (PLE). Only populated on the pre-process (PP stage 0) side, + # since the auxiliary signal is derived from `input_ids` / the token embedding output. + # See `modeling_gemma4.py` L1574-L1592 for the reference construction. Built with + # megatron-native parallel modules (mirroring `LanguageModelEmbedding` at + # `gpt_model.py` L150-L157) so the aux signal follows the TP/SP layout of the + # primary embedding. + self.hidden_size_per_layer_input = getattr(text_config, 'hidden_size_per_layer_input', None) + if self.hidden_size_per_layer_input and self.pre_process: + num_layers = text_config.num_hidden_layers + hidden_size = text_config.hidden_size + total_dim = num_layers * self.hidden_size_per_layer_input + tp_size = self.config.tensor_model_parallel_size + # Pad aux vocab size to be TP-divisible, matching how `GPTModel` pads the main + # `padded_vocab_size` before feeding it into `VocabParallelEmbedding`. + padded_vocab_size_per_layer = math.ceil(text_config.vocab_size_per_layer_input / tp_size) * tp_size + # Vocab-parallel embedding (shard on vocab dim). HF's `Gemma4TextScaledWordEmbedding` + # applies an `embed_scale = hidden_size_per_layer_input**0.5` factor on forward; + # we capture the scale as a sibling attribute so the weight shape stays 1:1 with HF. + self.embed_tokens_per_layer = VocabParallelEmbedding( + num_embeddings=padded_vocab_size_per_layer, + embedding_dim=total_dim, + init_method=self.config.init_method, + config=self.config, + tp_group=self.pg_collection.tp, + ) + self.embed_tokens_per_layer_scale = self.hidden_size_per_layer_input**0.5 + self.per_layer_input_scale = 2.0**-0.5 + # Column-parallel projection: output dim `num_layers * hidden_size_per_layer_input` + # is split across TP ranks so each rank produces its own shard of the packed + # per-layer input tensor. 
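# Illustrative numbers for the per-layer-embedding (PLE) sizes built above; these are made-up
# values, not a published gemma4 config. The aux vocab is padded up to a multiple of the TP
# size before sharding, and the packed projection width is one slice per decoder layer.
import math

vocab_size_per_layer_input = 262_145      # deliberately not TP-divisible
tp_size = 8
num_hidden_layers = 30
hidden_size_per_layer_input = 256

padded_vocab = math.ceil(vocab_size_per_layer_input / tp_size) * tp_size
total_dim = num_hidden_layers * hidden_size_per_layer_input
assert padded_vocab == 262_152 and padded_vocab % tp_size == 0
assert total_dim == 7_680 and total_dim % tp_size == 0    # 960 columns per TP rank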
+ self.per_layer_model_projection = build_module( + TEColumnParallelLinear, + hidden_size, + total_dim, + config=self.config, + init_method=self.config.init_method, + gather_output=False, + bias=False, + skip_bias_add=True, + is_expert=False, + tp_comm_buffer_name='per_layer_model_projection', + tp_group=self.pg_collection.tp, + ) + self.per_layer_model_projection_scale = hidden_size**-0.5 + self.per_layer_projection_norm = build_module( + TENorm, + hidden_size=self.hidden_size_per_layer_input, + config=self.config, + eps=self.config.layernorm_epsilon, + ) def _set_inv_freq(self): rope_scaling = self.config.rope_scaling @@ -187,7 +311,68 @@ def _set_inv_freq(self): class Gemma4TransformerLayer(CustomTransformerLayer): - pass + + def __init__(self, config, submodules, *args, **kwargs): + super().__init__(config, submodules, *args, **kwargs) + text_config = config.hf_config.text_config + hidden_size = self.config.hidden_size + eps = self.config.layernorm_epsilon + + # HF keeps an extra layernorm after self-attn / feedforward (before the residual add). + # mcore's TransformerLayer does not include these, so attach them here. + self.post_attention_layernorm = build_module(TENorm, hidden_size=hidden_size, config=self.config, eps=eps) + self.post_feedforward_layernorm = build_module(TENorm, hidden_size=hidden_size, config=self.config, eps=eps) + + # HF: `self.register_buffer("layer_scalar", torch.ones(1))` + self.register_buffer('layer_scalar', torch.ones(1)) + + # HF: per-layer input projection branch, only when `hidden_size_per_layer_input` is set. + self.hidden_size_per_layer_input = getattr(text_config, 'hidden_size_per_layer_input', None) + if self.hidden_size_per_layer_input: + from transformers.activations import ACT2FN + self.act_fn = ACT2FN[text_config.hidden_activation] + # Megatron-style parallel linears (see attention.py L348-361 for `linear_proj`): + # `per_layer_input_gate` is column-parallel (output dim split across TP), then its + # output is consumed by the row-parallel `per_layer_projection` which gathers along TP. + self.per_layer_input_gate = build_module( + TEColumnParallelLinear, + hidden_size, + self.hidden_size_per_layer_input, + config=self.config, + init_method=self.config.init_method, + gather_output=False, + bias=False, + skip_bias_add=True, + is_expert=False, + tp_comm_buffer_name='per_layer_input_gate', + tp_group=self.pg_collection.tp, + ) + self.per_layer_projection = build_module( + TERowParallelLinear, + self.hidden_size_per_layer_input, + hidden_size, + config=self.config, + init_method=self.config.output_layer_init_method, + bias=False, + input_is_parallel=True, + skip_bias_add=True, + is_expert=False, + tp_comm_buffer_name='per_layer_projection', + tp_group=self.pg_collection.tp, + ) + self.post_per_layer_input_norm = build_module(TENorm, hidden_size=hidden_size, config=self.config, eps=eps) + + # HF: extra layernorms when the layer runs a MoE block in parallel with the dense MLP. + # Router / experts modules are gemma4-specific and intentionally skipped here; they can + # be wired by the bridge/forward override once their mcore counterparts are implemented. 
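# How the three per-layer modules built above are meant to compose is not shown in these
# patches. As a reading aid only, one composition that is consistent with the module names and
# their parallel layouts (an assumption, not the reference implementation) would be:
#
#     gate = act_fn(per_layer_input_gate(hidden_states)[0])      # column-parallel, sharded on TP
#     mixed = gate * per_layer_input                              # elementwise gate of the PLE slice
#     mixed, _ = per_layer_projection(mixed)                      # row-parallel, gathers across TP
#     hidden_states = hidden_states + post_per_layer_input_norm(mixed)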
+ self.enable_moe_block = getattr(text_config, 'enable_moe_block', False) + if self.enable_moe_block: + self.post_feedforward_layernorm_1 = build_module( + TENorm, hidden_size=hidden_size, config=self.config, eps=eps) + self.post_feedforward_layernorm_2 = build_module( + TENorm, hidden_size=hidden_size, config=self.config, eps=eps) + self.pre_feedforward_layernorm_2 = build_module( + TENorm, hidden_size=hidden_size, config=self.config, eps=eps) class Gemma4GPTModel(MultimodalGPTModel): From 7e6fb75b421009316e07009590b2788c70e36d03 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 9 May 2026 16:52:11 +0800 Subject: [PATCH 24/26] update --- src/mcore_bridge/model/mm_gpts/gemma4.py | 54 +++++++++++++++++++++++- 1 file changed, 52 insertions(+), 2 deletions(-) diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index 7ba0faf..54dc50d 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -47,18 +47,65 @@ class Gemma4Vit(HuggingFaceVit): } _vision_tower = ['vision_tower', 'audio_tower'] _aligner = ['embed_vision', 'embed_audio'] - support_multimodal = False def prepare_model(self, hf_config: PretrainedConfig): - from transformers.models.gemma4.modeling_gemma4 import Gemma4MultimodalEmbedder + from transformers.models.gemma4.modeling_gemma4 import Gemma4MultimodalEmbedder, Gemma4Model self.vision_tower = AutoModel.from_config(hf_config.vision_config) self.audio_tower = AutoModel.from_config(hf_config.audio_config) if hf_config.audio_config is not None else None self.embed_vision = Gemma4MultimodalEmbedder(hf_config.vision_config, hf_config.text_config) self.embed_audio = ( Gemma4MultimodalEmbedder(hf_config.audio_config, hf_config.text_config) if hf_config.audio_config is not None else None) + self.register_buffer("embed_scale", torch.tensor(hf_config.hidden_size**0.5), persistent=False) + self.model_cls = Gemma4Model def get_inputs_embeds(self, inputs_embeds, **kwargs): + input_ids = kwargs.get('input_ids') + inputs_embeds *= self.embed_scale.to(inputs_embeds.dtype) + + hf_config = self.hf_config + input_ids = kwargs.get('input_ids') + pixel_values = kwargs.get('pixel_values') + pixel_values_videos = kwargs.get('pixel_values_videos') + input_features = kwargs.get('input_features') + input_features_mask = kwargs.get('input_features_mask') + image_position_ids = kwargs.get('image_position_ids') + video_position_ids = kwargs.get('video_position_ids') + + image_mask = input_ids == hf_config.image_token_id + video_mask = input_ids == hf_config.video_token_id + audio_mask = input_ids == hf_config.audio_token_id + + if pixel_values is not None: + vision_outputs = self.vision_tower( + pixel_values=pixel_values.to(self.vision_tower.dtype), + pixel_position_ids=image_position_ids, + ) + image_features = self.embed_vision(inputs_embeds=vision_outputs.last_hidden_state) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + image_mask_e = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + inputs_embeds = inputs_embeds.masked_scatter(image_mask_e, image_features) + + if pixel_values_videos is not None: + pixel_values_videos_flat = pixel_values_videos.flatten(0, 1) + video_position_ids_flat = video_position_ids.flatten(0, 1) if video_position_ids is not None else None + vision_outputs = self.vision_tower( + pixel_values=pixel_values_videos_flat.to(self.vision_tower.dtype), + pixel_position_ids=video_position_ids_flat, + ) + video_features = 
self.embed_vision(inputs_embeds=vision_outputs.last_hidden_state) + video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) + video_mask_e = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + inputs_embeds = inputs_embeds.masked_scatter(video_mask_e, video_features) + + if (input_features is not None and input_features_mask is not None and self.audio_tower is not None): + audio_outputs = self.audio_tower(input_features, input_features_mask, return_dict=True) + audio_features = self.embed_audio(inputs_embeds=audio_outputs.last_hidden_state) + audio_features = audio_features[audio_outputs.attention_mask] + audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) + audio_mask_e = audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + inputs_embeds = inputs_embeds.masked_scatter(audio_mask_e, audio_features) + return inputs_embeds @@ -309,6 +356,9 @@ def _set_inv_freq(self): self.config.rope_scaling = rope_scaling + def forward(self): + pass + class Gemma4TransformerLayer(CustomTransformerLayer): From e1d085192d83801dd6e1348d1ee2bdc19821a545 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 9 May 2026 16:52:22 +0800 Subject: [PATCH 25/26] update --- src/mcore_bridge/bridge/gpt_bridge.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py index 2c410a2..a1899d7 100644 --- a/src/mcore_bridge/bridge/gpt_bridge.py +++ b/src/mcore_bridge/bridge/gpt_bridge.py @@ -44,8 +44,8 @@ class GPTBridge: hf_gate_key = 'gate.weight' hf_shared_expert_key = None hf_expert_bias_key = 'gate.e_score_correction_bias' - additional_dim0_keys = {} - additional_dim1_keys = {} + additional_dim0_keys = set() + additional_dim1_keys = set() def __init__(self, config: ModelConfig): self.config = config From e3cbe5db1bfdca03c1c35fd833a9b5924bb88237 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 9 May 2026 18:13:19 +0800 Subject: [PATCH 26/26] update --- src/mcore_bridge/model/mm_gpts/gemma4.py | 101 ++++++++--------------- 1 file changed, 36 insertions(+), 65 deletions(-) diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index 54dc50d..3b92bbf 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -2,6 +2,7 @@ import copy import math import torch +import torch.distributed as dist from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TENorm, TERowParallelLinear from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec from megatron.core.tensor_parallel import VocabParallelEmbedding @@ -49,14 +50,15 @@ class Gemma4Vit(HuggingFaceVit): _aligner = ['embed_vision', 'embed_audio'] def prepare_model(self, hf_config: PretrainedConfig): - from transformers.models.gemma4.modeling_gemma4 import Gemma4MultimodalEmbedder, Gemma4Model + from transformers.models.gemma4.modeling_gemma4 import Gemma4Model, Gemma4MultimodalEmbedder self.vision_tower = AutoModel.from_config(hf_config.vision_config) + dtype = self.vision_tower.dtype self.audio_tower = AutoModel.from_config(hf_config.audio_config) if hf_config.audio_config is not None else None - self.embed_vision = Gemma4MultimodalEmbedder(hf_config.vision_config, hf_config.text_config) + self.embed_vision = Gemma4MultimodalEmbedder(hf_config.vision_config, hf_config.text_config).to(dtype) self.embed_audio = ( - 
Gemma4MultimodalEmbedder(hf_config.audio_config, hf_config.text_config) + Gemma4MultimodalEmbedder(hf_config.audio_config, hf_config.text_config).to(dtype) if hf_config.audio_config is not None else None) - self.register_buffer("embed_scale", torch.tensor(hf_config.hidden_size**0.5), persistent=False) + self.register_buffer('embed_scale', torch.tensor(hf_config.hidden_size**0.5).to(dtype), persistent=False) self.model_cls = Gemma4Model def get_inputs_embeds(self, inputs_embeds, **kwargs): @@ -75,38 +77,35 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs): image_mask = input_ids == hf_config.image_token_id video_mask = input_ids == hf_config.video_token_id audio_mask = input_ids == hf_config.audio_token_id + multimodal_mask = image_mask | video_mask | audio_mask + llm_input_ids = input_ids.clone() + llm_input_ids[multimodal_mask] = hf_config.text_config.pad_token_id if pixel_values is not None: - vision_outputs = self.vision_tower( - pixel_values=pixel_values.to(self.vision_tower.dtype), - pixel_position_ids=image_position_ids, - ) - image_features = self.embed_vision(inputs_embeds=vision_outputs.last_hidden_state) + with self.patch_hf_config(): + image_features = self.model_cls.get_image_features( + self, pixel_values, image_position_ids, return_dict=True).pooler_output image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) image_mask_e = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) inputs_embeds = inputs_embeds.masked_scatter(image_mask_e, image_features) if pixel_values_videos is not None: - pixel_values_videos_flat = pixel_values_videos.flatten(0, 1) - video_position_ids_flat = video_position_ids.flatten(0, 1) if video_position_ids is not None else None - vision_outputs = self.vision_tower( - pixel_values=pixel_values_videos_flat.to(self.vision_tower.dtype), - pixel_position_ids=video_position_ids_flat, - ) - video_features = self.embed_vision(inputs_embeds=vision_outputs.last_hidden_state) + with self.patch_hf_config(): + video_features = self.get_video_features( + pixel_values_videos, video_position_ids, return_dict=True).pooler_output video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) video_mask_e = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) inputs_embeds = inputs_embeds.masked_scatter(video_mask_e, video_features) if (input_features is not None and input_features_mask is not None and self.audio_tower is not None): - audio_outputs = self.audio_tower(input_features, input_features_mask, return_dict=True) - audio_features = self.embed_audio(inputs_embeds=audio_outputs.last_hidden_state) - audio_features = audio_features[audio_outputs.attention_mask] + with self.patch_hf_config(): + audio_output = self.get_audio_features(input_features, input_features_mask, return_dict=True) + audio_features = audio_output.pooler_output + audio_features = audio_features[audio_output.attention_mask] audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) audio_mask_e = audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) inputs_embeds = inputs_embeds.masked_scatter(audio_mask_e, audio_features) - - return inputs_embeds + return {'inputs_embeds': inputs_embeds, 'llm_input_ids': llm_input_ids} class Gemma4SelfAttention(SelfAttention): @@ -155,9 +154,6 @@ def __init__( self.layer_type in prev_layers and layer_idx == len(prev_layers) - 1 - prev_layers[::-1].index(self.layer_type)) - # Patch config so the underlying linear_qkv/q_layernorm/k_layernorm 
are built correctly. - # HF keeps `q_norm` on every layer, but only builds `k_norm`/`v_norm` on non-kv-shared - # layers, so replace `k_layernorm` with `IdentityOp` when this layer shares KV. orig_kv_channels = config.kv_channels orig_num_query_groups = config.num_query_groups orig_k_layernorm = submodules.k_layernorm @@ -172,10 +168,6 @@ def __init__( config.num_query_groups = orig_num_query_groups submodules.k_layernorm = orig_k_layernorm - # HF kv-shared layers only keep `q_proj` (K/V are reused from an earlier layer), so the - # default mcore `linear_qkv` shape `[Q + 2*KV, hidden]` over-allocates. Rebuild it with - # out_dim = query_projection_size so shapes match HF `q_proj` 1:1 for weight bridging. - # Mirrors attention.py L1275-L1289, minus the `+ 2 * kv_projection_size` term. if self.is_kv_shared_layer: self.linear_qkv_out_dim = self.query_projection_size self.linear_qkv = submodules.linear_qkv( @@ -191,8 +183,6 @@ def __init__( tp_group=self.pg_collection.tp, ) - # HF builds a `v_norm` (RMSNorm without learnable scale) for non-kv-shared layers. - # mcore's SelfAttention has no v_layernorm by default, so attach one explicitly here. self.v_norm = ( Gemma4VNorm(self.head_dim, eps=self.config.layernorm_epsilon) if not self.is_kv_shared_layer else None) @@ -236,7 +226,7 @@ def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool): is_kv_shared_layer = False if mg_attn is None else mg_attn.is_kv_shared_layer is_kv_shared_layer = torch.tensor([is_kv_shared_layer], dtype=torch.bool, device='cuda') if self.pp_size > 1: - dist.all_reduce(is_lora, group=self.pp_group, op=dist.ReduceOp.MAX) + dist.all_reduce(is_kv_shared_layer, group=self.pp_group, op=dist.ReduceOp.MAX) is_kv_shared_layer = is_kv_shared_layer.item() if is_kv_shared_layer: self._set_state_dict(mg_attn, 'linear_qkv.weight', hf_state_dict, 'q_proj.weight', to_mcore) @@ -257,12 +247,7 @@ def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: i 'per_layer_projection', 'post_per_layer_input_norm' ]: self._set_state_dict( - mg_layer, - f'{key}.weight', - hf_state_dict if to_mcore else new_hf_state_dict, - f'{key}.weight', - to_mcore, - _check_mg_param=False) + mg_layer, f'{key}.weight', hf_state_dict, f'{key}.weight', to_mcore, _check_mg_param=False) if to_mcore: hf_state_dict = {} else: @@ -281,29 +266,18 @@ class Gemma4TextGPTModel(GPTModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.pad_embedding = self.embedding.word_embeddings.weight text_config = self.config.hf_config.text_config - # HF: `self.unique_layer_types = set(self.config.layer_types)` — needed by the rotary - # embedding selection logic (sliding vs global) when that path is wired up. + self.text_config = text_config self.unique_layer_types = set(text_config.layer_types) - # HF: Per-Layer Embeddings (PLE). Only populated on the pre-process (PP stage 0) side, - # since the auxiliary signal is derived from `input_ids` / the token embedding output. - # See `modeling_gemma4.py` L1574-L1592 for the reference construction. Built with - # megatron-native parallel modules (mirroring `LanguageModelEmbedding` at - # `gpt_model.py` L150-L157) so the aux signal follows the TP/SP layout of the - # primary embedding. 
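# Shape check for the per-layer-input reshape used in the forward added below: the flat
# [batch, seq, num_layers * hidden_per_layer] embedding output is unpacked into one slice per
# decoder layer. Sizes are illustrative only.
import torch

batch, seq, num_layers, hidden_per_layer = 2, 5, 30, 256
flat = torch.randn(batch, seq, num_layers * hidden_per_layer)
per_layer_inputs = flat.reshape(batch, seq, num_layers, hidden_per_layer)
assert per_layer_inputs.shape == (2, 5, 30, 256)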
self.hidden_size_per_layer_input = getattr(text_config, 'hidden_size_per_layer_input', None) if self.hidden_size_per_layer_input and self.pre_process: num_layers = text_config.num_hidden_layers hidden_size = text_config.hidden_size total_dim = num_layers * self.hidden_size_per_layer_input tp_size = self.config.tensor_model_parallel_size - # Pad aux vocab size to be TP-divisible, matching how `GPTModel` pads the main - # `padded_vocab_size` before feeding it into `VocabParallelEmbedding`. padded_vocab_size_per_layer = math.ceil(text_config.vocab_size_per_layer_input / tp_size) * tp_size - # Vocab-parallel embedding (shard on vocab dim). HF's `Gemma4TextScaledWordEmbedding` - # applies an `embed_scale = hidden_size_per_layer_input**0.5` factor on forward; - # we capture the scale as a sibling attribute so the weight shape stays 1:1 with HF. self.embed_tokens_per_layer = VocabParallelEmbedding( num_embeddings=padded_vocab_size_per_layer, embedding_dim=total_dim, @@ -313,9 +287,6 @@ def __init__(self, *args, **kwargs): ) self.embed_tokens_per_layer_scale = self.hidden_size_per_layer_input**0.5 self.per_layer_input_scale = 2.0**-0.5 - # Column-parallel projection: output dim `num_layers * hidden_size_per_layer_input` - # is split across TP ranks so each rank produces its own shard of the packed - # per-layer input tensor. self.per_layer_model_projection = build_module( TEColumnParallelLinear, hidden_size, @@ -356,8 +327,18 @@ def _set_inv_freq(self): self.config.rope_scaling = rope_scaling - def forward(self): - pass + def forward(self, *args, **kwargs): + extra_block_kwargs = kwargs.pop('extra_block_kwargs', None) or {} + llm_input_ids = extra_block_kwargs.pop('llm_input_ids', None) + if self.hidden_size_per_layer_input and self.pre_process: + per_layer_inputs = (self.embed_tokens_per_layer(llm_input_ids) * self.embed_tokens_per_layer_scale).reshape( + *llm_input_ids.shape, + self.text_config.num_hidden_layers, + self.hidden_size_per_layer_input, + ) + extra_block_kwargs['per_layer_inputs'] = per_layer_inputs + kwargs['extra_block_kwargs'] = extra_block_kwargs + return super().forward(*args, **kwargs) class Gemma4TransformerLayer(CustomTransformerLayer): @@ -368,22 +349,15 @@ def __init__(self, config, submodules, *args, **kwargs): hidden_size = self.config.hidden_size eps = self.config.layernorm_epsilon - # HF keeps an extra layernorm after self-attn / feedforward (before the residual add). - # mcore's TransformerLayer does not include these, so attach them here. self.post_attention_layernorm = build_module(TENorm, hidden_size=hidden_size, config=self.config, eps=eps) self.post_feedforward_layernorm = build_module(TENorm, hidden_size=hidden_size, config=self.config, eps=eps) - # HF: `self.register_buffer("layer_scalar", torch.ones(1))` self.register_buffer('layer_scalar', torch.ones(1)) - # HF: per-layer input projection branch, only when `hidden_size_per_layer_input` is set. self.hidden_size_per_layer_input = getattr(text_config, 'hidden_size_per_layer_input', None) if self.hidden_size_per_layer_input: from transformers.activations import ACT2FN self.act_fn = ACT2FN[text_config.hidden_activation] - # Megatron-style parallel linears (see attention.py L348-361 for `linear_proj`): - # `per_layer_input_gate` is column-parallel (output dim split across TP), then its - # output is consumed by the row-parallel `per_layer_projection` which gathers along TP. 
self.per_layer_input_gate = build_module( TEColumnParallelLinear, hidden_size, @@ -412,9 +386,6 @@ def __init__(self, config, submodules, *args, **kwargs): ) self.post_per_layer_input_norm = build_module(TENorm, hidden_size=hidden_size, config=self.config, eps=eps) - # HF: extra layernorms when the layer runs a MoE block in parallel with the dense MLP. - # Router / experts modules are gemma4-specific and intentionally skipped here; they can - # be wired by the bridge/forward override once their mcore counterparts are implemented. self.enable_moe_block = getattr(text_config, 'enable_moe_block', False) if self.enable_moe_block: self.post_feedforward_layernorm_1 = build_module(