36 commits
4e2df75 update (Jintao-Huang, Apr 29, 2026)
c388954 update (Jintao-Huang, Apr 29, 2026)
76af2bc update (Jintao-Huang, Apr 29, 2026)
54e3343 update (Jintao-Huang, Apr 29, 2026)
25a45bd update (Jintao-Huang, Apr 30, 2026)
3210617 update (Jintao-Huang, Apr 30, 2026)
e1355dd Merge branch 'main' into support_gemma4 (Jintao-Huang, May 3, 2026)
5b4e118 update (Jintao-Huang, May 4, 2026)
d1d2246 update (Jintao-Huang, May 4, 2026)
24cd697 Merge branch 'main' into support_gemma4 (Jintao-Huang, May 4, 2026)
bba8144 merge (Jintao-Huang, May 4, 2026)
14b1644 fix (Jintao-Huang, May 5, 2026)
196a58f update (Jintao-Huang, May 5, 2026)
68e33a7 update (Jintao-Huang, May 5, 2026)
9736c3e Merge branch 'main' into support_gemma4 (Jintao-Huang, May 5, 2026)
44ddaec update (Jintao-Huang, May 5, 2026)
b3cc043 Merge branch 'main' into support_gemma4 (Jintao-Huang, May 5, 2026)
7563bc4 Merge branch 'main' into support_gemma4 (Jintao-Huang, May 8, 2026)
4a74289 Merge branch 'main' into support_gemma4 (Jintao-Huang, May 8, 2026)
0de0ebb update (Jintao-Huang, May 9, 2026)
2a81bf0 update (Jintao-Huang, May 9, 2026)
7e05d3d update (Jintao-Huang, May 9, 2026)
8da05df fix (Jintao-Huang, May 9, 2026)
d1eff8a fix (Jintao-Huang, May 9, 2026)
e545c4f Merge remote-tracking branch 'refs/remotes/origin/support_gemma4' int… (Jintao-Huang, May 9, 2026)
0c22e68 fix (Jintao-Huang, May 9, 2026)
63511bd Merge remote-tracking branch 'refs/remotes/origin/support_gemma4' int… (Jintao-Huang, May 9, 2026)
fa5360b fix (Jintao-Huang, May 9, 2026)
d25db28 update (Jintao-Huang, May 9, 2026)
bfbcbc4 update (Jintao-Huang, May 9, 2026)
2300825 fix (Jintao-Huang, May 9, 2026)
cda31a5 update (Jintao-Huang, May 9, 2026)
7e6fb75 update (Jintao-Huang, May 9, 2026)
e1d0851 update (Jintao-Huang, May 9, 2026)
0178948 Merge branch 'main' into support_gemma4 (Jintao-Huang, May 9, 2026)
e3cbe5d update (Jintao-Huang, May 9, 2026)
29 changes: 20 additions & 9 deletions src/mcore_bridge/bridge/gpt_bridge.py
@@ -40,9 +40,12 @@ class GPTBridge:
hf_o_proj_key = 'o_proj'
hf_attn_prefix = 'self_attn'
hf_mlp_prefix = 'mlp'
hf_post_attention_layernorm = 'post_attention_layernorm'
hf_gate_key = 'gate.weight'
hf_shared_expert_key = None
hf_expert_bias_key = 'gate.e_score_correction_bias'
additional_dim0_keys = set()
additional_dim1_keys = set()

def __init__(self, config: ModelConfig):
self.config = config
@@ -124,11 +127,11 @@ def _get_tp_split_dim(self, mg_key: Optional[str]) -> Optional[int]:
'linear_kv_up_proj',
# mtp
'eh_proj',
}
} | self.additional_dim0_keys
if self.config.task_type in {'causal_lm', 'generative_reranker'}:
dim0_keys.add('output_layer')
# RowLinear
dim1_keys = {'out_proj', 'linear_proj', 'linear_fc2'}
dim1_keys = {'out_proj', 'linear_proj', 'linear_fc2'} | self.additional_dim1_keys
if 'lora_A' not in mg_key and 'lora_B' not in mg_key:
key, suffix = mg_key.rsplit('.', 2)[-2:]
if suffix == 'layer_norm_weight':
@@ -439,7 +442,8 @@ def _set_state_dict(self,
to_mcore: bool,
*,
offset: float = 0,
is_expert: bool = False):
is_expert: bool = False,
_check_mg_param: bool = True):
if '.' in mg_key:
module_key, param_key = mg_key.rsplit('.', 1)
else:
@@ -487,7 +491,11 @@ def _set_state_dict(self,
else:
mg_param = deep_getattr(sub_module, param_key)
if to_mcore:
assert mg_param is not None, f'mg_module: {mg_module}, mg_key: {mg_key}'
if mg_param is None:
if _check_mg_param:
raise ValueError(f'mg_module: {mg_module}, mg_key: {mg_key}')
else:
return
hf_weight = hf_state_dict[hf_key].load()
if module_key in {
'embedding.word_embeddings', 'output_layer'
@@ -1587,13 +1595,13 @@ def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool
hf_state_dict.update(
self._set_moe_state(
mg_mlp, hf_state_dict, f'{self.hf_mlp_prefix}.', layer_idx, to_mcore, is_mtp=is_mtp))
self._set_state_dict(mg_layer, 'pre_mlp_layernorm.weight', hf_state_dict, 'post_attention_layernorm.weight',
to_mcore)
self._set_state_dict(mg_layer, 'pre_mlp_layernorm.weight', hf_state_dict,
f'{self.hf_post_attention_layernorm}.weight', to_mcore)
else:
hf_state_dict.update(
self._set_mlp_state(mg_mlp, hf_state_dict, f'{self.hf_mlp_prefix}.', layer_idx, to_mcore))
self._set_state_dict(mg_layer, 'mlp.linear_fc1.layer_norm_weight', hf_state_dict,
'post_attention_layernorm.weight', to_mcore)
f'{self.hf_post_attention_layernorm}.weight', to_mcore)
return hf_state_dict

def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool):
@@ -1610,13 +1618,16 @@ def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool):
hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix)
return hf_state_dict

def _set_word_embeddings(self, mg_model, hf_state_dict, to_mcore):
lm_model = getattr(mg_model, 'language_model') if self.is_multimodal else mg_model
self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, self.hf_embed_key, to_mcore)

def _convert_pre_process(self, mg_model, hf_state_dict, hf_prefix: str, to_mcore):
if to_mcore:
hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
else:
hf_state_dict = {}
lm_model = getattr(mg_model, 'language_model') if self.is_multimodal else mg_model
self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, self.hf_embed_key, to_mcore)
self._set_word_embeddings(mg_model, hf_state_dict, to_mcore)
if self.is_multimodal:
for prefix, mg_prefix in self.module_mapping.items():
mg_module = deep_getattr(mg_model, f'visual.{mg_prefix}')
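Note on the gpt_bridge.py changes: previously hard-coded pieces of GPTBridge become class-level hooks. hf_post_attention_layernorm lets a subclass rename the HF post-attention norm key, additional_dim0_keys / additional_dim1_keys let it extend the ColumnParallel (dim-0) and RowParallel (dim-1) tensor-parallel split sets, and _check_mg_param=False lets _set_state_dict skip a Megatron parameter that does not exist on the current model instead of raising. A minimal sketch of a subclass using these hooks; the class name and key values are illustrative assumptions, not taken from this PR:

# Hypothetical subclass; name and key values are illustrative assumptions.
class Gemma4Bridge(GPTBridge):
    # HF checkpoints that name the post-attention norm differently only need
    # to override this string; _set_layer_mlp reads it via
    # f'{self.hf_post_attention_layernorm}.weight'.
    hf_post_attention_layernorm = 'post_feedforward_layernorm'

    # Extra parameter keys merged into the tensor-parallel split sets in
    # _get_tp_split_dim: dim-0 for ColumnParallel, dim-1 for RowParallel.
    additional_dim0_keys = {'linear_q_up_proj'}
    additional_dim1_keys = {'linear_out_proj'}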
2 changes: 2 additions & 0 deletions src/mcore_bridge/config/parser.py
@@ -149,6 +149,8 @@ def hf_to_mcore_config(hf_config: PretrainedConfig) -> Dict[str, Any]:
n_shared_experts = res.pop('n_shared_experts')
elif llm_model_type in {'ernie4_5', 'ernie4_5_moe', 'glm4'}:
res['rotary_interleaved'] = True
elif hf_model_type in {'gemma4'}:
res['qk_layernorm'] = True
elif llm_model_type == 'gpt_oss':
res['add_bias_linear'] = True
res['bias_dropout_fusion'] = False
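Note on the parser.py change: hf_to_mcore_config translates a HuggingFace config into Megatron-Core arguments, and each branch applies per-model-type overrides; the new branch enables qk_layernorm for gemma4. A self-contained sketch of the dispatch pattern; the function name is an illustrative assumption, not the bridge's real API:

# Sketch of the per-model-type override pattern; function name is an
# illustrative assumption.
def apply_model_type_overrides(model_type: str, res: dict) -> dict:
    if model_type in {'ernie4_5', 'ernie4_5_moe', 'glm4'}:
        res['rotary_interleaved'] = True
    elif model_type in {'gemma4'}:
        res['qk_layernorm'] = True  # gemma4 normalizes Q/K before attention
    elif model_type == 'gpt_oss':
        res['add_bias_linear'] = True
        res['bias_dropout_fusion'] = False
    return res

assert apply_model_type_overrides('gemma4', {}) == {'qk_layernorm': True}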
1 change: 1 addition & 0 deletions src/mcore_bridge/model/constant.py
@@ -28,6 +28,7 @@ class MLLMModelType:
glm4v_moe = 'glm4v_moe'
kimi_vl = 'kimi_vl'
llama4 = 'llama4'
gemma4 = 'gemma4'

kimi_k25 = 'kimi_k25'

67 changes: 38 additions & 29 deletions src/mcore_bridge/model/gpt_model.py
@@ -100,9 +100,7 @@ def __init__(
for i in range(len(self.decoder.layers)):
if hasattr(self.decoder.layers[i].self_attention, 'rotary_pos_emb'):
del self.decoder.layers[i].self_attention.rotary_pos_emb
self.attention_scaling = 1.
new_inv_freq, self.attention_scaling = get_rope_inv_freq(config)
self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device)
self._set_inv_freq()
if self.config.task_type == 'seq_cls' and self.post_process:
self.output_layer = OutputLayerLinear(
config.hidden_size,
@@ -212,7 +210,36 @@ def _preprocess(
if decoder_input is not None and self.training and torch.is_grad_enabled() and not decoder_input.requires_grad:
# fix LoRA incompatibility with gradient checkpointing
decoder_input = decoder_input.requires_grad_(True)
rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin = self._get_rotary_pos_emb(
decoder_input, position_ids, packed_seq_params=packed_seq_params)
Comment on lines +213 to +214
Severity: high

The inference_context is not passed to the _get_rotary_pos_emb method. This will cause the method to skip critical inference-specific logic, such as utilizing the RoPE cache or correctly calculating the rotary sequence length for flash decoding, which can lead to performance degradation or incorrect results during inference.

Suggested change
-        rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin = self._get_rotary_pos_emb(
-            decoder_input, position_ids, packed_seq_params=packed_seq_params)
+        rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin = self._get_rotary_pos_emb(
+            decoder_input, position_ids, packed_seq_params=packed_seq_params, inference_context=inference_context)


if (in_inference_mode and ((self.config.enable_cuda_graph and self.config.cuda_graph_scope != 'full_iteration')
or self.config.flash_decode) and rotary_pos_cos is not None
and inference_context.is_static_batching()):
current_batch_size = input_ids.shape[0]
sequence_len_offset = torch.tensor(
[inference_context.sequence_len_offset] * current_batch_size,
dtype=torch.int32,
device=rotary_pos_cos.device, # Co-locate this with the rotary tensors
)
else:
sequence_len_offset = None

# Wrap decoder_input to allow the decoder (TransformerBlock) to delete the
# reference held by this caller function, enabling early garbage collection for
# inference. Skip wrapping if decoder_input is logged after decoder completion.
if in_inference_mode and not has_config_logger_enabled(self.config):
decoder_input = WrappedTensor(decoder_input)

return (decoder_input, rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin,
sequence_len_offset)

def _set_inv_freq(self):
self.attention_scaling = 1.
new_inv_freq, self.attention_scaling = get_rope_inv_freq(self.config)
self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device)

def _get_rotary_pos_emb(self, decoder_input, position_ids, packed_seq_params, inference_context=None):
# Rotary positional embeddings (embedding is None for PP intermediate devices)
rotary_pos_emb = None
rotary_pos_cos = None
@@ -247,26 +274,13 @@ def _preprocess(
rotary_seq_len,
packed_seq=packed_seq,
)
decoder_rotary_pos_emb = rotary_pos_emb
packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd'
if self.position_embedding_type == 'rope' and packed_seq and not self.config.apply_rope_fusion:
assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}'
decoder_rotary_pos_emb = rotary_pos_emb[position_ids[0]]

if (in_inference_mode and ((self.config.enable_cuda_graph and self.config.cuda_graph_scope != 'full_iteration')
or self.config.flash_decode) and rotary_pos_cos is not None
and inference_context.is_static_batching()):
current_batch_size = input_ids.shape[0]
sequence_len_offset = torch.tensor(
[inference_context.sequence_len_offset] * current_batch_size,
dtype=torch.int32,
device=rotary_pos_cos.device, # Co-locate this with the rotary tensors
)
else:
sequence_len_offset = None

# Wrap decoder_input to allow the decoder (TransformerBlock) to delete the
# reference held by this caller function, enabling early garbage collection for
# inference. Skip wrapping if decoder_input is logged after decoder completion.
if in_inference_mode and not has_config_logger_enabled(self.config):
decoder_input = WrappedTensor(decoder_input)

return decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset
return rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin

# Code borrowed from NVIDIA/Megatron-LM
def forward(
@@ -296,20 +310,15 @@ def forward(
"""

inference_context = deprecate_inference_params(inference_context, inference_params)

decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset = (
# The decoder and MTP differ in whether rotary_pos_emb can be fused.
decoder_input, rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset = (
self._preprocess(
input_ids=input_ids,
position_ids=position_ids,
decoder_input=decoder_input,
inference_context=inference_context,
packed_seq_params=packed_seq_params,
))
decoder_rotary_pos_emb = rotary_pos_emb
packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd'
if self.position_embedding_type == 'rope' and packed_seq and not self.config.apply_rope_fusion:
assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}'
decoder_rotary_pos_emb = rotary_pos_emb[position_ids[0]]

mtp_decoder_input = decoder_input
if self.config.is_multimodal and self.config.mtp_num_layers and decoder_input is None:
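Note on the gpt_model.py changes: the rotary computation moves into a dedicated _get_rotary_pos_emb, which now also returns a separate decoder_rotary_pos_emb, since the decoder and MTP differ in whether rotary_pos_emb can be fused. For packed 'thd' sequences without rope fusion, the per-position table is gathered with the flattened position_ids so every token of every packed sub-sequence picks up the embedding for its own position. A self-contained sketch of that gather; shapes are assumptions for illustration:

import torch

# The rope table is indexed by absolute position: [seq, 1, 1, head_dim].
max_seq_len, head_dim = 8, 4
rotary_pos_emb = torch.randn(max_seq_len, 1, 1, head_dim)

# Two packed sequences of lengths 3 and 5, flattened into one 'thd' row.
position_ids = torch.tensor([[0, 1, 2, 0, 1, 2, 3, 4]])
assert position_ids.shape[0] == 1  # packing assumes a single flattened row

# Same indexing as decoder_rotary_pos_emb = rotary_pos_emb[position_ids[0]]:
# each token gathers the row for its position within its sub-sequence.
decoder_rotary_pos_emb = rotary_pos_emb[position_ids[0]]
print(decoder_rotary_pos_emb.shape)  # torch.Size([8, 1, 1, 4])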
8 changes: 5 additions & 3 deletions src/mcore_bridge/model/gpts/minimax_m2.py
@@ -27,9 +27,11 @@ def __init__(
k_layernorm = submodules.k_layernorm
submodules.q_layernorm = IdentityOp
submodules.k_layernorm = IdentityOp
super().__init__(config, submodules, *args, **kwargs)
submodules.q_layernorm = q_layernorm
submodules.k_layernorm = k_layernorm
try:
super().__init__(config, submodules, *args, **kwargs)
finally:
submodules.q_layernorm = q_layernorm
submodules.k_layernorm = k_layernorm
self.q_norm = build_module(
submodules.q_layernorm,
hidden_size=self.hidden_size_per_attention_head * config.num_attention_heads,
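Note on the minimax_m2.py change: the submodules spec is temporarily mutated (q/k layernorms swapped for IdentityOp) so the parent constructor does not build them, and the try/finally guarantees the spec is restored even if super().__init__ raises; because the spec object is shared, leaving it mutated would corrupt construction of later layers. A generic sketch of the swap-and-restore pattern; the helper name is an illustrative assumption:

from contextlib import contextmanager

@contextmanager
def temporarily_set(obj, attr, value):
    # Swap in a temporary attribute value and restore the original even if
    # the body raises, mirroring the try/finally above.
    old = getattr(obj, attr)
    setattr(obj, attr, value)
    try:
        yield
    finally:
        setattr(obj, attr, old)

# Usage sketch (inside an __init__ like the one above):
#     with temporarily_set(submodules, 'q_layernorm', IdentityOp):
#         super().__init__(config, submodules, *args, **kwargs)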
2 changes: 1 addition & 1 deletion src/mcore_bridge/model/mm_gpts/__init__.py
@@ -1,2 +1,2 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from . import glm, internvl, kimi_vl, llama4, llava, qwen, qwen3_5, qwen3_5_gdn, qwen3_omni, qwen3_vl
from . import gemma4, glm, internvl, kimi_vl, llama4, llava, qwen, qwen3_5, qwen3_5_gdn, qwen3_omni, qwen3_vl
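Note on the mm_gpts/__init__.py change: modules in this package are imported for their side effects, with each module typically registering its model type at import time, so adding gemma4 to the import list is what activates it. A self-contained sketch of that import-for-registration pattern; the registry and decorator names are assumptions, not this package's real API:

# Registry sketch; names are assumptions.
MODEL_REGISTRY: dict = {}

def register_model(model_type: str):
    def wrapper(cls):
        MODEL_REGISTRY[model_type] = cls  # importing the module populates this
        return cls
    return wrapper

@register_model('gemma4')
class Gemma4GPTModel:
    """Placeholder body for illustration."""

assert 'gemma4' in MODEL_REGISTRY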