Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
4e2df75
update
Jintao-Huang Apr 29, 2026
c388954
update
Jintao-Huang Apr 29, 2026
76af2bc
update
Jintao-Huang Apr 29, 2026
54e3343
update
Jintao-Huang Apr 29, 2026
25a45bd
update
Jintao-Huang Apr 30, 2026
3210617
update
Jintao-Huang Apr 30, 2026
e1355dd
Merge branch 'main' into support_gemma4
Jintao-Huang May 3, 2026
5b4e118
update
Jintao-Huang May 4, 2026
d1d2246
update
Jintao-Huang May 4, 2026
24cd697
Merge branch 'main' into support_gemma4
Jintao-Huang May 4, 2026
bba8144
merge
Jintao-Huang May 4, 2026
14b1644
fix
Jintao-Huang May 5, 2026
196a58f
update
Jintao-Huang May 5, 2026
68e33a7
update
Jintao-Huang May 5, 2026
9736c3e
Merge branch 'main' into support_gemma4
Jintao-Huang May 5, 2026
44ddaec
update
Jintao-Huang May 5, 2026
b3cc043
Merge branch 'main' into support_gemma4
Jintao-Huang May 5, 2026
7563bc4
Merge branch 'main' into support_gemma4
Jintao-Huang May 8, 2026
4a74289
Merge branch 'main' into support_gemma4
Jintao-Huang May 8, 2026
0de0ebb
update
Jintao-Huang May 9, 2026
2a81bf0
update
Jintao-Huang May 9, 2026
7e05d3d
update
Jintao-Huang May 9, 2026
8da05df
fix
Jintao-Huang May 9, 2026
d1eff8a
fix
Jintao-Huang May 9, 2026
e545c4f
Merge remote-tracking branch 'refs/remotes/origin/support_gemma4' int…
Jintao-Huang May 9, 2026
0c22e68
fix
Jintao-Huang May 9, 2026
63511bd
Merge remote-tracking branch 'refs/remotes/origin/support_gemma4' int…
Jintao-Huang May 9, 2026
fa5360b
fix
Jintao-Huang May 9, 2026
d25db28
update
Jintao-Huang May 9, 2026
bfbcbc4
update
Jintao-Huang May 9, 2026
2300825
fix
Jintao-Huang May 9, 2026
cda31a5
update
Jintao-Huang May 9, 2026
7e6fb75
update
Jintao-Huang May 9, 2026
e1d0851
update
Jintao-Huang May 9, 2026
0178948
Merge branch 'main' into support_gemma4
Jintao-Huang May 9, 2026
e3cbe5d
update
Jintao-Huang May 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/mcore_bridge/config/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
import os
import re
import torch.nn.functional as F
from dataclasses import dataclass
from dataclasses import dataclass, field
from megatron.core import mpu
from megatron.core.transformer import TransformerConfig
from transformers import PretrainedConfig
from transformers.utils import is_torch_npu_available
from transformers.utils.versions import require_version
from typing import List, Literal, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union

from mcore_bridge.utils import get_logger, json_parse_to_dict

Expand Down Expand Up @@ -229,6 +229,7 @@ class ModelConfig(TransformerConfig):
task_type: Literal['causal_lm', 'seq_cls', 'embedding', 'generative_reranker'] = 'causal_lm'
num_labels: Optional[int] = None
mlp_padding_free: bool = False
model_kwargs: Dict[str, Any] = field(default_factory=dict)

_mindspeed_defaults_cache = None

Expand Down
1 change: 1 addition & 0 deletions src/mcore_bridge/model/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class MLLMModelType:
glm4v_moe = 'glm4v_moe'
kimi_vl = 'kimi_vl'
llama4 = 'llama4'
gemma4 = 'gemma4'

kimi_k25 = 'kimi_k25'

Expand Down
65 changes: 37 additions & 28 deletions src/mcore_bridge/model/gpt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,7 @@ def __init__(
for i in range(len(self.decoder.layers)):
if hasattr(self.decoder.layers[i].self_attention, 'rotary_pos_emb'):
del self.decoder.layers[i].self_attention.rotary_pos_emb
self.attention_scaling = 1.
new_inv_freq, self.attention_scaling = get_rope_inv_freq(config)
self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device)
self._set_inv_freq()
if self.config.task_type == 'seq_cls' and self.post_process:
self.output_layer = OutputLayerLinear(
config.hidden_size,
Expand Down Expand Up @@ -222,7 +220,36 @@ def _preprocess(
if decoder_input is not None and self.training and torch.is_grad_enabled() and not decoder_input.requires_grad:
# fix LoRA incompatibility with gradient checkpointing
decoder_input = decoder_input.requires_grad_(True)
rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin = self._get_rotary_pos_emb(
decoder_input, position_ids, packed_seq_params=packed_seq_params)
Comment on lines +213 to +214
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The inference_context is not passed to the _get_rotary_pos_emb method. This will cause the method to skip critical inference-specific logic, such as utilizing the RoPE cache or correctly calculating the rotary sequence length for flash decoding, which can lead to performance degradation or incorrect results during inference.

Suggested change
rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin = self._get_rotary_pos_emb(
decoder_input, position_ids, packed_seq_params=packed_seq_params)
rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin = self._get_rotary_pos_emb(
decoder_input, position_ids, packed_seq_params=packed_seq_params, inference_context=inference_context)


if (in_inference_mode and ((self.config.enable_cuda_graph and self.config.cuda_graph_scope != 'full_iteration')
or self.config.flash_decode) and rotary_pos_cos is not None
and inference_context.is_static_batching()):
current_batch_size = input_ids.shape[0]
sequence_len_offset = torch.tensor(
[inference_context.sequence_len_offset] * current_batch_size,
dtype=torch.int32,
device=rotary_pos_cos.device, # Co-locate this with the rotary tensors
)
else:
sequence_len_offset = None

# Wrap decoder_input to allow the decoder (TransformerBlock) to delete the
# reference held by this caller function, enabling early garbage collection for
# inference. Skip wrapping if decoder_input is logged after decoder completion.
if in_inference_mode and not has_config_logger_enabled(self.config):
decoder_input = WrappedTensor(decoder_input)

return (decoder_input, rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin,
sequence_len_offset)

def _set_inv_freq(self):
    """Recompute the RoPE inverse-frequency table from ``self.config``.

    ``get_rope_inv_freq`` returns both the table and the attention scaling
    factor, which is stored on the model; the table is moved onto the device
    of the existing ``rotary_pos_emb.inv_freq`` buffer before replacing it.
    """
    # Removed the dead store ``self.attention_scaling = 1.`` — it was
    # unconditionally overwritten by the unpack on the next line.
    new_inv_freq, self.attention_scaling = get_rope_inv_freq(self.config)
    self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device)

def _get_rotary_pos_emb(self, decoder_input, position_ids, packed_seq_params, inference_context=None):
# Rotary positional embeddings (embedding is None for PP intermediate devices)
rotary_pos_emb = None
rotary_pos_cos = None
Expand Down Expand Up @@ -257,26 +284,13 @@ def _preprocess(
rotary_seq_len,
packed_seq=packed_seq,
)
decoder_rotary_pos_emb = rotary_pos_emb
packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd'
if self.position_embedding_type == 'rope' and packed_seq and not self.config.apply_rope_fusion:
assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}'
decoder_rotary_pos_emb = rotary_pos_emb[position_ids[0]]

if (in_inference_mode and ((self.config.enable_cuda_graph and self.config.cuda_graph_scope != 'full_iteration')
or self.config.flash_decode) and rotary_pos_cos is not None
and inference_context.is_static_batching()):
current_batch_size = input_ids.shape[0]
sequence_len_offset = torch.tensor(
[inference_context.sequence_len_offset] * current_batch_size,
dtype=torch.int32,
device=rotary_pos_cos.device, # Co-locate this with the rotary tensors
)
else:
sequence_len_offset = None

# Wrap decoder_input to allow the decoder (TransformerBlock) to delete the
# reference held by this caller function, enabling early garbage collection for
# inference. Skip wrapping if decoder_input is logged after decoder completion.
if in_inference_mode and not has_config_logger_enabled(self.config):
decoder_input = WrappedTensor(decoder_input)

return decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset
return rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin

# Code borrowed from NVIDIA/Megatron-LM
def forward(
Expand Down Expand Up @@ -308,19 +322,14 @@ def forward(

inference_context = deprecate_inference_params(inference_context, inference_params)

decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset = (
decoder_input, rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset = (
self._preprocess(
input_ids=input_ids,
position_ids=position_ids,
decoder_input=decoder_input,
inference_context=inference_context,
packed_seq_params=packed_seq_params,
))
decoder_rotary_pos_emb = rotary_pos_emb
packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd'
if self.position_embedding_type == 'rope' and packed_seq and not self.config.apply_rope_fusion:
assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}'
decoder_rotary_pos_emb = rotary_pos_emb[position_ids[0]]

mtp_decoder_input = decoder_input
if self.config.is_multimodal and self.config.mtp_num_layers and decoder_input is None:
Expand Down
4 changes: 3 additions & 1 deletion src/mcore_bridge/model/mm_gpt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@


class MultimodalGPTModel(MegatronModule):
language_model_cls = GPTModel

def __init__(self,
config: ModelConfig,
Expand All @@ -29,7 +30,8 @@ def __init__(self,
super().__init__(config)
self.pre_process = pre_process
self.post_process = post_process
self.language_model = GPTModel(config, transformer_layer_spec, pre_process, post_process, *_args, **kwargs)
self.language_model = self.language_model_cls(config, transformer_layer_spec, pre_process, post_process, *_args,
**kwargs)
self.vp_stage = self.language_model.vp_stage
self.share_embeddings_and_output_weights = self.language_model.share_embeddings_and_output_weights
self.model_meta = config.model_meta
Expand Down
2 changes: 1 addition & 1 deletion src/mcore_bridge/model/mm_gpts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
# Importing the submodules triggers their register_model() calls.
from . import gemma4, glm, internvl, kimi_vl, llama4, llava, qwen, qwen3_5, qwen3_5_gdn, qwen3_omni, qwen3_vl
86 changes: 86 additions & 0 deletions src/mcore_bridge/model/mm_gpts/gemma4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import copy
from transformers import AutoModel, PretrainedConfig

from mcore_bridge.bridge import GPTBridge

from ..constant import ModelType
from ..gpt_model import GPTModel
from ..mm_gpt_model import MultimodalGPTModel
from ..register import ModelLoader, ModelMeta, register_model
from ..rope import get_rope_inv_freq
from .utils import HuggingFaceVit


class Gemma4Vit(HuggingFaceVit):
    """Vision/audio tower wrapper for Gemma4.

    Holds the HF vision and (optional) audio encoders plus their
    multimodal embedders, exposed under the attribute names used by
    ``module_mapping``.
    """

    # Maps HF checkpoint module paths to the local attribute names below.
    module_mapping = {
        'model.vision_tower': 'vision_tower',
        'model.embed_vision': 'embed_vision',
        'model.audio_tower': 'audio_tower',
        'model.embed_audio': 'embed_audio',
    }
    # Encoder backbones (vision is mandatory, audio optional — see prepare_model).
    _vision_tower = ['vision_tower', 'audio_tower']
    # Embedder modules that project tower outputs into the text embedding space.
    _aligner = ['embed_vision', 'embed_audio']
    # NOTE(review): flag semantics defined by HuggingFaceVit — confirm meaning there.
    support_multimodal = False

    def prepare_model(self, hf_config: PretrainedConfig):
        """Instantiate towers and embedders from ``hf_config``.

        The audio tower and its embedder are optional: both are left
        ``None`` when ``hf_config.audio_config`` is ``None``.
        """
        from transformers.models.gemma4.modeling_gemma4 import Gemma4MultimodalEmbedder
        self.vision_tower = AutoModel.from_config(hf_config.vision_config)
        self.audio_tower = AutoModel.from_config(hf_config.audio_config) if hf_config.audio_config is not None else None
        self.embed_vision = Gemma4MultimodalEmbedder(hf_config.vision_config, hf_config.text_config)
        self.embed_audio = (
            Gemma4MultimodalEmbedder(hf_config.audio_config, hf_config.text_config)
            if hf_config.audio_config is not None else None)

    def get_inputs_embeds(self, inputs_embeds, **kwargs):
        """Identity passthrough: return ``inputs_embeds`` unchanged, ignoring ``kwargs``."""
        return inputs_embeds


class Gemma4Bridge(GPTBridge):
    """Weight-conversion bridge for Gemma4; inherits GPTBridge behavior unchanged."""
    pass


class Gemma4TextGPTModel(GPTModel):
    """GPT language model for Gemma4, which uses two RoPE tables: one for
    sliding-window (local) attention layers and one for full (global)
    attention layers."""

    def _set_inv_freq(self):
        """Build inverse-frequency tables for both Gemma4 RoPE variants.

        Gemma4 stores rope parameters as a nested dict keyed by
        ``'sliding_attention'`` / ``'full_attention'``. Each sub-config is
        temporarily swapped into ``self.config.rope_scaling`` so the generic
        ``get_rope_inv_freq`` helper can consume it; the nested dict is
        restored in a ``finally`` so a failed assertion cannot leave the
        config mutated.
        """
        rope_scaling = self.config.rope_scaling
        try:
            # Sliding-window (local) attention RoPE.
            self.config.rope_scaling = rope_scaling['sliding_attention']
            new_inv_freq, attention_scaling = get_rope_inv_freq(self.config)
            assert attention_scaling == 1, (
                f'Gemma4 sliding-attention RoPE: attention_scaling != 1.0 is not '
                f'supported (got {attention_scaling}).')
            device = self.rotary_pos_emb.inv_freq.device
            self.rotary_pos_emb.inv_freq = new_inv_freq.to(device)

            # Full (global) attention RoPE: shallow-copy the embedding module
            # and swap in its own inv_freq table.
            self.full_rotary_pos_emb = copy.copy(self.rotary_pos_emb)
            self.config.rope_scaling = rope_scaling['full_attention']
            kwargs = {}
            if self.config.rope_scaling['rope_type'] == 'proportional':
                # Global layers read their head dim from a different config key.
                kwargs['head_dim_key'] = 'global_head_dim'
            new_inv_freq, attention_scaling = get_rope_inv_freq(self.config, **kwargs)
            assert attention_scaling == 1, (
                f'Gemma4 full-attention RoPE: attention_scaling != 1.0 is not '
                f'supported (got {attention_scaling}).')
            # Bug fix: keep this table on the same device as the sliding one
            # (the original left it on CPU).
            self.full_rotary_pos_emb.inv_freq = new_inv_freq.to(device)
            self.attention_scaling = attention_scaling
        finally:
            # NOTE(review): restoring the nested dict here may break helpers
            # that expect a flat dict with a top-level 'rope_type' key (e.g.
            # _get_rope_type on later calls) — confirm against callers.
            self.config.rope_scaling = rope_scaling
Comment on lines +311 to +328
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The implementation of _set_inv_freq for Gemma4TextGPTModel has several issues:

  1. Potential Runtime Crash: Restoring self.config.rope_scaling to the original nested dictionary at line 62 will cause a KeyError in _get_rope_type (called via dynamic_rope_update during every forward pass) because that function expects a dictionary with a rope_type key at the top level, which the Gemma4 configuration lacks (it uses sliding_attention and full_attention as top-level keys).
  2. Dead Code: self.full_rotary_pos_emb is initialized but never utilized by the base GPTModel forward pass or RoPE application logic.
  3. Poor Error Messages: The assertion messages 'not support' at lines 49 and 58 are not descriptive. They should clearly state that attention scaling other than 1.0 is not supported for this model.



class Gemma4GPTModel(MultimodalGPTModel):
    """Multimodal Gemma4 model; swaps in the dual-RoPE text backbone."""
    language_model_cls = Gemma4TextGPTModel


class Gemma4Loader(ModelLoader):
    """Model loader binding the Gemma4 model type to its model class."""
    # Removed commented-out get_transformer_layer_spec draft (dead code
    # flagged in review); reintroduce from history if layer-spec
    # customization becomes necessary.
    model_cls = Gemma4GPTModel
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Please remove the commented-out code in Gemma4Loader to maintain code cleanliness and readability.



# Register Gemma4 so the bridge/loader framework can resolve the model type
# (matched by the 'gemma4' architecture name).
register_model(
    ModelMeta(
        ModelType.gemma4,
        ['gemma4'],
        bridge_cls=Gemma4Bridge,
        visual_cls=Gemma4Vit,
        loader=Gemma4Loader,
    ))
4 changes: 2 additions & 2 deletions src/mcore_bridge/model/rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,12 @@ def _get_rope_type(rope_scaling: Optional[Dict[str, Any]]):
return rope_type


def get_rope_inv_freq(config, seq_len=None, **kwargs):
    """Compute the RoPE inverse-frequency table for ``config``.

    Args:
        config: model config exposing ``rope_scaling``; converted to a dummy
            HF-style config for the rope init functions.
        seq_len: optional sequence length forwarded to the init function
            (used by dynamic rope types).
        **kwargs: extra keyword args forwarded verbatim to the init function
            (e.g. ``head_dim_key`` for Gemma4's global-attention RoPE).

    Returns:
        Tuple of ``(inv_freq, attention_scaling)``; ``attention_scaling``
        defaults to ``1.`` when the init function returns ``None``.
    """
    from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
    # Make project-specific rope types resolvable alongside the HF built-ins.
    ROPE_INIT_FUNCTIONS.update(EXTENDED_ROPE_INIT_FUNCTIONS)
    dummy_config = _get_dummy_config(config)
    rope_init_fn = ROPE_INIT_FUNCTIONS[_get_rope_type(config.rope_scaling)]
    inv_freq, attention_scaling = rope_init_fn(dummy_config, 'cpu', seq_len=seq_len, **kwargs)
    if attention_scaling is None:
        attention_scaling = 1.
    return inv_freq, attention_scaling
Expand Down
39 changes: 3 additions & 36 deletions src/mcore_bridge/tuners/patcher.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import copy
from contextlib import contextmanager
from megatron.core.extensions.transformer_engine import TEGroupedLinear, TELayerNormColumnParallelLinear, TELinear
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.moe.router import TopKRouter
Expand All @@ -11,6 +9,8 @@
from torch import nn
from typing import Optional

from mcore_bridge.utils import patch_deepcopy

from .lora import LoraParallelLinear


Expand All @@ -37,47 +37,14 @@ def dispatch_megatron(
model.dispatch_megatron = dispatch_megatron


@contextmanager
def _patch_deepcopy():
_origin_deepcopy = copy.deepcopy
copy_keys = ('tp_group', '_tp_group', 'config')

def new_deepcopy(x, *args, **kwargs):
if not isinstance(x, nn.Module):
return _origin_deepcopy(x, *args, **kwargs)

saved_values = {}
for key in copy_keys:
val = getattr(x, key, None)
if val is not None:
saved_values[key] = val
setattr(x, key, None)

try:
res = _origin_deepcopy(x, *args, **kwargs)
finally:
for key, value in saved_values.items():
setattr(x, key, value)

for key, value in saved_values.items():
setattr(res, key, value)
return res

copy.deepcopy = new_deepcopy
try:
yield
finally:
copy.deepcopy = _origin_deepcopy


def _patch_lora_model():
if hasattr(LoraModel, '_mcore_patched'):
return

__origin_init__ = LoraModel.__init__

def __new_init__(self, *args, **kwargs):
with _patch_deepcopy():
with patch_deepcopy():
__origin_init__(self, *args, **kwargs)
if not isinstance(self.model, MegatronModule):
return
Expand Down
2 changes: 1 addition & 1 deletion src/mcore_bridge/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@
from .megatron_utils import get_local_layer_specs, set_random_seed, split_cp_inputs, unwrap_model
from .safetensors import SafetensorLazyLoader, StreamingSafetensorSaver
from .torch_utils import gc_collect, get_current_device, safe_ddp_context, to_device
from .utils import deep_getattr, get_env_args, json_parse_to_dict, patch_deepcopy
36 changes: 36 additions & 0 deletions src/mcore_bridge/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import copy
import json
import os
from contextlib import contextmanager
from torch import nn
from transformers.utils import strtobool
from typing import Callable, Dict, Optional, TypeVar, Union

Expand Down Expand Up @@ -58,3 +61,36 @@ def deep_getattr(obj, attr: str, default=None):
else:
obj = getattr(obj, a, default)
return obj


@contextmanager
def patch_deepcopy():
    """Temporarily replace ``copy.deepcopy`` with an ``nn.Module``-aware variant.

    While the context is active, deep-copying an ``nn.Module`` skips the
    attributes ``tp_group``, ``_tp_group`` and ``config`` (process groups and
    shared configs are not safely deep-copyable): they are detached from the
    source module for the duration of the copy, restored on it afterwards, and
    attached to the copy *by reference*. Non-module objects are copied
    normally. The original ``copy.deepcopy`` is reinstated on exit, even if
    the body raises.
    """
    original_deepcopy = copy.deepcopy
    shared_attrs = ('tp_group', '_tp_group', 'config')

    def module_aware_deepcopy(obj, *args, **kwargs):
        if not isinstance(obj, nn.Module):
            return original_deepcopy(obj, *args, **kwargs)

        # Detach the non-copyable attributes, remembering their values.
        stashed = {}
        for name in shared_attrs:
            value = getattr(obj, name, None)
            if value is None:
                continue
            stashed[name] = value
            setattr(obj, name, None)

        try:
            duplicate = original_deepcopy(obj, *args, **kwargs)
        finally:
            # Always restore the source module, even if the copy failed.
            for name, value in stashed.items():
                setattr(obj, name, value)

        # Share (not copy) the stashed values with the duplicate.
        for name, value in stashed.items():
            setattr(duplicate, name, value)
        return duplicate

    copy.deepcopy = module_aware_deepcopy
    try:
        yield
    finally:
        copy.deepcopy = original_deepcopy
Loading