22 changes: 18 additions & 4 deletions src/transformers/conversion_mapping.py
@@ -70,11 +70,26 @@ def _build_checkpoint_conversion_mapping():
operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
),
WeightConverter(
source_patterns=["mlp.experts.*.down_proj.weight"],
source_patterns="mlp.experts.*.down_proj.weight",
target_patterns="mlp.experts.down_proj",
operations=[MergeModulelist(dim=0)],
),
],
"lfm2_moe": [
WeightConverter(
source_patterns=[
"feed_forward.experts.*.gate_proj.weight",
"feed_forward.experts.*.up_proj.weight",
],
target_patterns="feed_forward.experts.gate_up_proj",
operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
),
WeightConverter(
source_patterns="feed_forward.experts.*.down_proj.weight",
target_patterns="feed_forward.experts.down_proj",
operations=[MergeModulelist(dim=0)],
),
],
"timm_wrapper": [
# Simply add the prefix `timm_model`
# TODO: Would probably be much cleaner with an `add_prefix` argument in WeightRenaming
@@ -117,15 +132,14 @@ def _build_checkpoint_conversion_mapping():
),
]

mapping["phimoe"] = mapping["mixtral"].copy()
mapping["phimoe"] = mapping["qwen2_moe"].copy()
mapping["deepseek_v2"] = mapping["qwen2_moe"].copy()
mapping["deepseek_v3"] = mapping["qwen2_moe"].copy()
mapping["dot1"] = mapping["qwen2_moe"].copy()
mapping["ernie_4_5_moe"] = mapping["qwen2_moe"].copy()
mapping["glm4_moe"] = mapping["qwen2_moe"].copy()
mapping["glm4v_moe"] = mapping["qwen2_moe"].copy()
mapping["jamba"] = mapping["qwen2_moe"].copy()
mapping["lfm2_moe"] = mapping["mixtral"].copy()
mapping["jamba"] = mapping["lfm2_moe"].copy()
mapping["long_cat_flash"] = mapping["qwen2_moe"].copy()
mapping["qwen3_moe"] = mapping["qwen2_moe"].copy()
mapping["qwen3_omni_moe"] = mapping["qwen2_moe"].copy()
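For reference, a minimal sketch of what the new `lfm2_moe` entries do to the checkpoint tensors. Plain `torch.stack`/`torch.cat` stand in for `MergeModulelist(dim=0)` and `Concatenate(dim=1)`, and the sizes are made up; only the shapes matter.

```python
# Illustrative only: fuses per-expert LFM2-MoE weights the way the lfm2_moe
# WeightConverter entries describe (merge the module list, then concatenate).
import torch

num_experts, hidden, inter = 4, 8, 16
gate = [torch.randn(inter, hidden) for _ in range(num_experts)]  # experts.*.gate_proj.weight
up = [torch.randn(inter, hidden) for _ in range(num_experts)]    # experts.*.up_proj.weight
down = [torch.randn(hidden, inter) for _ in range(num_experts)]  # experts.*.down_proj.weight

# MergeModulelist(dim=0): stack each per-expert list into a single tensor
gate_stacked = torch.stack(gate, dim=0)  # (num_experts, inter, hidden)
up_stacked = torch.stack(up, dim=0)      # (num_experts, inter, hidden)

# Concatenate(dim=1): fuse gate and up into one gate_up_proj parameter
gate_up_proj = torch.cat([gate_stacked, up_stacked], dim=1)  # (num_experts, 2 * inter, hidden)

# down_proj only needs the stacking step
down_proj = torch.stack(down, dim=0)  # (num_experts, hidden, inter)

assert gate_up_proj.shape == (num_experts, 2 * inter, hidden)
assert down_proj.shape == (num_experts, hidden, inter)
```

These fused shapes match the `Lfm2MoeExperts` parameters further down (`gate_up_proj` of shape `(num_experts, 2 * intermediate, hidden)`, `down_proj` of shape `(num_experts, hidden, intermediate)`), which is why the expert forward can index `self.gate_up_proj[expert_idx]` and split it with `chunk(2, dim=-1)`.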
24 changes: 13 additions & 11 deletions src/transformers/core_model_loading.py
@@ -322,15 +322,15 @@ def __post_init__(self):
# This is ugly but needed for reverse mapping of Qwen2.5!
if r"(?!\.(language_model|visual))" in pattern:
pattern = pattern.replace(r"(?!\.(language_model|visual))", "")
# Allow capturing groups in patterns, i.e. to add a prefix to all keys (e.g. timm_wrapper)
# Allow capturing groups in patterns, i.e. to add/remove a prefix to all keys (e.g. timm_wrapper, sam3)
if r"(.+)" in pattern:
pattern = pattern.replace(r"(.+)", "")
pattern = pattern.replace(r"(.+)", r"\1")
self.target_patterns[i] = pattern

# We also need to check capturing groups in the sources during reverse mapping (e.g. timm_wrapper)
# We also need to check capturing groups in the sources during reverse mapping (e.g. timm_wrapper, sam3)
for i, pattern in enumerate(self.source_patterns):
if r"\1" in pattern:
pattern = pattern.replace(r"\1", "")
pattern = pattern.replace(r"\1", r"(.+)")
self.source_patterns[i] = pattern

def add_tensor(self, target_key: str, source_key: str, source_pattern: str, future: Future):
@@ -628,8 +628,11 @@ def repl(m, repl_map: dict[str, str]) -> str:
name = matched_groups[0]
replacement = repl_map[name]
# Allow capturing groups in patterns, i.e. to add a prefix to all keys (e.g. timm_wrapper)
if r"\1" in replacement and len(m.groups()) > 1:
replacement = replacement.replace(r"\1", m.group(1))
if r"\1" in replacement:
# If we find a capturing group, the corresponding parenthesized group is the one right after the named
# group we matched, as it is nested inside that named group
group_idx_to_replace = m.re.groupindex[name] + 1
replacement = replacement.replace(r"\1", m.group(group_idx_to_replace))

return replacement

@@ -647,17 +650,16 @@ def rename_source_key(
Rename a source key given all the renaming and weight conversion patterns we have. Also takes care of adding/removing
the base model prefix during loading if necessary.
"""
# 1. apply all renamings
renamed_key = rename_alternation.sub(lambda m: repl(m, rename_by_group), source_key).replace("\\", "")
# 1. apply all renamings (only the first match of the alternation should be replaced if there are several, hence count=1)
renamed_key = rename_alternation.sub(lambda m: repl(m, rename_by_group), source_key, count=1)

# 2. apply renaming through weight conversions on the key if we have any WeightConverter
matched_converter_pattern = (
weight_pattern_alternation.search(renamed_key) if weight_pattern_alternation is not None else None
)
if matched_converter_pattern is not None:
renamed_key = weight_pattern_alternation.sub(lambda m: repl(m, weight_pattern_by_group), renamed_key).replace(
"\\", ""
)
# only the first match of the alternation should be replaced if there are several, hence count=1
renamed_key = weight_pattern_alternation.sub(lambda m: repl(m, weight_pattern_by_group), renamed_key, count=1)

# 3. check if we need to add or remove the prefix (only during loading, not saving)
if prefix is not None and meta_state_dict is not None:
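A standalone sketch of the `repl` fix: all renaming patterns are compiled into one alternation of named groups, so when a pattern carries its own `(.+)` capturing group, that group sits at the named group's index plus one. The group names (`g0`, `g1`) and the example keys below are invented for illustration; the real code builds the alternation from the conversion mappings.

```python
# Minimal reproduction of the named-group / nested-group indexing used in repl().
import re

rename_by_group = {
    "g0": "lm_head",  # plain rename
    "g1": r"\1",      # strip the "detector_model." prefix, keep the captured rest
}
alternation = re.compile(r"(?P<g0>output_layer)|(?P<g1>detector_model\.(.+))")

def repl(m: re.Match, repl_map: dict[str, str]) -> str:
    # find which named group actually matched
    name = [n for n, v in m.groupdict().items() if v is not None][0]
    replacement = repl_map[name]
    if r"\1" in replacement:
        # the unnamed capturing group is nested right after its named group
        group_idx_to_replace = m.re.groupindex[name] + 1
        replacement = replacement.replace(r"\1", m.group(group_idx_to_replace))
    return replacement

# count=1 mirrors rename_source_key: only the first alternation match is rewritten
print(alternation.sub(lambda m: repl(m, rename_by_group), "output_layer.weight", count=1))
# -> lm_head.weight
print(alternation.sub(lambda m: repl(m, rename_by_group), "detector_model.vision_encoder.weight", count=1))
# -> vision_encoder.weight
```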
4 changes: 4 additions & 0 deletions src/transformers/models/lfm2_moe/configuration_lfm2_moe.py
@@ -54,6 +54,8 @@ class Lfm2MoeConfig(PreTrainedConfig):
with longer `max_position_embeddings`.
max_position_embeddings (`int`, *optional*, defaults to 128000):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
@@ -117,6 +119,7 @@ def __init__(
tie_word_embeddings: bool = True,
rope_parameters: RopeParameters = None,
max_position_embeddings: int = 128_000,
initializer_range: float = 0.02,
use_cache: bool = True,
norm_eps: float = 0.00001,
num_attention_heads: int = 32,
@@ -140,6 +143,7 @@ def __init__(
rope_scaling = kwargs.pop("rope_scaling", None)
self.rope_parameters = rope_scaling or rope_parameters
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.use_cache = use_cache
self.norm_eps = norm_eps

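For context on the new `initializer_range` field, this is the usual way such a value is consumed at init time (a generic sketch of the common pattern, not code taken from the Lfm2Moe implementation):

```python
# Generic sketch: initializer_range becomes the std of the truncated-normal init.
import torch
from torch import nn

class TinyConfig:
    initializer_range: float = 0.02  # stand-in for Lfm2MoeConfig.initializer_range

def init_weights(module: nn.Module, config: TinyConfig) -> None:
    if isinstance(module, nn.Linear):
        nn.init.trunc_normal_(module.weight, mean=0.0, std=config.initializer_range)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

layer = nn.Linear(16, 16)
init_weights(layer, TinyConfig())
print(layer.weight.std().item())  # roughly 0.02
```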
4 changes: 1 addition & 3 deletions src/transformers/models/lfm2_moe/modeling_lfm2_moe.py
@@ -25,7 +25,6 @@
from torch import nn

from ... import initialization as init
from ...activations import ACT2FN
from ...cache_utils import Cache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
@@ -156,7 +155,6 @@ def __init__(self, config):
self.intermediate_dim = config.moe_intermediate_size
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]

def forward(
self,
@@ -178,7 +176,7 @@ def forward(
_, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = gate * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
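The routing bookkeeping in this forward is easier to see on a toy example (made-up sizes, no projections); it only exercises the dispatch: a one-hot mask over the top-k expert indices, a permute so the expert dimension leads, and `torch.where` to recover which tokens each expert has to process.

```python
# Toy dispatch walk-through matching the masking logic in the forward above.
import torch

num_tokens, num_experts, top_k = 5, 4, 2
router_scores = torch.rand(num_tokens, num_experts)  # dense per-expert weights, (tokens, E)
_, top_k_index = router_scores.topk(top_k, dim=-1)   # routed expert ids per token, (tokens, top_k)

expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)  # (tokens, top_k, E + 1)
expert_mask = expert_mask.permute(2, 1, 0)                                            # (E + 1, top_k, tokens)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

for expert_idx in expert_hit:
    expert_idx = expert_idx[0]
    if expert_idx == num_experts:  # the extra class from num_classes=num_experts + 1 is never a real expert
        continue
    _, token_idx = torch.where(expert_mask[expert_idx])
    # token_idx lists the tokens routed to this expert; the real forward applies
    # gate_up_proj[expert_idx] / down_proj[expert_idx] to hidden_states[token_idx],
    # scales by the router weights and accumulates the result with index_add_.
    print(int(expert_idx), token_idx.tolist())
```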
31 changes: 30 additions & 1 deletion src/transformers/models/lfm2_moe/modular_lfm2_moe.py
@@ -69,7 +69,36 @@ def __init__(self, config: Lfm2MoeConfig, intermediate_size: Optional[int] = Non


class Lfm2MoeExperts(Qwen2MoeExperts):
pass
def __init__(self, config):
super().__init__(config)
del self.act_fn

def forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
) -> torch.Tensor:
final_hidden_states = torch.zeros_like(hidden_states)
num_experts = top_k_weights.shape[1]
with torch.no_grad():
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
expert_mask = expert_mask.permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
continue
_, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = gate * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

return final_hidden_states


class Lfm2MoeSparseMoeBlock(nn.Module):
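The `del self.act_fn` line relies on a small modular-transformers idiom: reuse the parent's `__init__`, then delete the attribute the child no longer needs, so the generated modeling file ends up without it. A toy illustration with dummy modules (not the real classes):

```python
# Dummy classes showing the attribute-deletion override pattern.
import torch
from torch import nn

class ParentExperts(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(4, 4))
        self.act_fn = nn.SiLU()

class ChildExperts(ParentExperts):
    def __init__(self):
        super().__init__()
        del self.act_fn  # the child gates with a plain product, no activation attribute needed

print(hasattr(ChildExperts(), "act_fn"))  # False
```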
4 changes: 3 additions & 1 deletion src/transformers/models/sam3/modeling_sam3.py
@@ -2089,7 +2089,9 @@ def _embed_pixels(

class Sam3Model(Sam3PreTrainedModel):
input_modalities = ["image", "text"]
_checkpoint_conversion_mapping = {"detector_model.": ""}
_checkpoint_conversion_mapping = {
r"detector_model.(.+)": r"\1"
} # the regex allows removing the prefix, and adding it back in reverse mode
_keys_to_ignore_on_load_unexpected = [
r"^tracker_model.",
r"^tracker_neck.",
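A standalone sketch of why the capturing-group form is preferable to the old plain prefix strip: applied forward it removes `detector_model.`, and with source and target swapped (which is what `WeightConverter.__post_init__` does for the reverse pass, see the core_model_loading.py hunk above) it puts the prefix back. Plain `re` is used here instead of the actual loading machinery, and the example key is made up.

```python
# Forward and reverse application of the regex-based conversion mapping.
import re

mapping = {r"detector_model.(.+)": r"\1"}
pattern, replacement = next(iter(mapping.items()))

checkpoint_key = "detector_model.vision_encoder.patch_embed.weight"

# forward (loading): drop the prefix
model_key = re.sub(pattern, replacement, checkpoint_key, count=1)
print(model_key)  # vision_encoder.patch_embed.weight

# reverse (saving in the original format): swap the roles, turning r"\1" back into a
# capturing group and the old pattern into the replacement template
reverse_pattern = replacement.replace(r"\1", r"(.+)")  # "(.+)"
reverse_replacement = pattern.replace(r"(.+)", r"\1")  # "detector_model.\1"
print(re.sub(reverse_pattern, reverse_replacement, model_key, count=1))
# detector_model.vision_encoder.patch_embed.weight
```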
@@ -768,9 +768,10 @@ class Sam3TrackerModel(Sam3TrackerPreTrainedModel):
"occlusion_spatial_embedding_parameter",
]
_checkpoint_conversion_mapping = {
"tracker_model.": "",
"detector_model.vision_encoder.backbone.": "vision_encoder.backbone.",
"tracker_neck.": "vision_encoder.neck.",
# This one needs to be last so that, in reverse mode, the more specific patterns above can match first
r"tracker_model.(.+)": r"\1", # the regex allows removing the prefix, and adding it back in reverse mode
}

def __init__(self, config: Sam3TrackerConfig):
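Why the comment insists the generic rule must come last: once reversed, its `r"\1"` target becomes the pattern `(.+)`, which matches every key, so the more specific rules have to be tried first. A small sketch with an assumed first-match helper and made-up keys:

```python
# First-match renaming over the reversed mapping; order decides which rule wins.
import re

reverse_rules = [
    (r"vision_encoder.backbone.", "detector_model.vision_encoder.backbone."),
    (r"vision_encoder.neck.", "tracker_neck."),
    (r"(.+)", r"tracker_model.\1"),  # generic prefix rule, must stay last
]

def rename_for_saving(key: str) -> str:
    for pattern, replacement in reverse_rules:
        if re.match(pattern, key):
            return re.sub(pattern, replacement, key, count=1)
    return key

print(rename_for_saving("vision_encoder.neck.conv.weight"))
# tracker_neck.conv.weight
print(rename_for_saving("memory_attention.layers.0.weight"))
# tracker_model.memory_attention.layers.0.weight
```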
3 changes: 2 additions & 1 deletion src/transformers/models/sam3_tracker/modular_sam3_tracker.py
@@ -180,9 +180,10 @@ class Sam3TrackerMaskDecoder(Sam2MaskDecoder):

class Sam3TrackerModel(Sam2Model):
_checkpoint_conversion_mapping = {
"tracker_model.": "",
"detector_model.vision_encoder.backbone.": "vision_encoder.backbone.",
"tracker_neck.": "vision_encoder.neck.",
# This one needs to be last so that, in reverse mode, the more specific patterns above can match first
r"tracker_model.(.+)": r"\1", # the regex allows removing the prefix, and adding it back in reverse mode
}
_keys_to_ignore_on_load_unexpected = [
r"^detector_model.",
@@ -1570,9 +1570,10 @@ class Sam3TrackerVideoModel(Sam3TrackerVideoPreTrainedModel):
_tied_weights_keys = {}
_keys_to_ignore_on_load_missing = []
_checkpoint_conversion_mapping = {
"tracker_model.": "",
"detector_model.vision_encoder.backbone.": "vision_encoder.backbone.",
"tracker_neck.": "vision_encoder.neck.",
# This one needs to be last so that, in reverse mode, the more specific patterns above can match first
r"tracker_model.(.+)": r"\1", # the regex allows removing the prefix, and adding it back in reverse mode
}

def __init__(self, config: Sam3TrackerVideoConfig, remove_vision_encoder: bool = False):
@@ -456,9 +456,10 @@ class Sam3TrackerVideoMaskDecoder(Sam2VideoMaskDecoder):

class Sam3TrackerVideoModel(Sam2VideoModel):
_checkpoint_conversion_mapping = {
"tracker_model.": "",
"detector_model.vision_encoder.backbone.": "vision_encoder.backbone.",
"tracker_neck.": "vision_encoder.neck.",
# This one needs to be last so that, in reverse mode, the more specific patterns above can match first
r"tracker_model.(.+)": r"\1", # the regex allows removing the prefix, and adding it back in reverse mode
}
_keys_to_ignore_on_load_unexpected = [r"^detector_model."]
_tied_weights_keys = {}
4 changes: 1 addition & 3 deletions tests/test_modeling_common.py
@@ -4058,7 +4058,6 @@ def test_tp_plan_matches_params(self):
len(unused_entries) == 0, f"The following entries of the TP-plan are not valid: {unused_entries}"
)

@unittest.skip("Some models have wrong mappings....")
def test_reverse_loading_mapping(self):
"""Make sure we can load and save correctly the models having any weight renaming mapping or weight conversion
mapping.
@@ -4073,7 +4072,7 @@ def test_reverse_loading_mapping(self):

# Some MoE models alternate between a classic MLP and a MoE layer, in which case we want to have at
# least one MoE layer here to check the mapping
config_to_set = config.get_text_config()
config_to_set = config.get_text_config(decoder=True)
config_to_set.first_k_dense_replace = 1 # means that the first layer (idx 0) will be MLP, then MoE
config_to_set.moe_layer_start_index = 1 # same as above but for Ernie 4.5...
config_to_set.mlp_only_layers = [0] # same but for qwens
@@ -4137,7 +4136,6 @@
# Make sure both saved state_dict are identical
self.assertTrue(compare_state_dicts(state_dict_saved_from_init, state_dict_saved_from_pretrained))

@unittest.skip("Some models have wrong mappings....")
def test_can_load_from_already_mapped_keys(self):
"""Test that we can correctly reload a model if we chose `save_original_format=False` in `save_pretrained`,
i.e. we do not reapply weight conversions when reloading if it was saved correctly already.
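The two re-enabled tests assert a round-trip property: keys converted on load should convert back on save, so a save/reload cycle is lossless. Roughly, with a placeholder model id and an assumed comparison helper (not the literal test body):

```python
# Round-trip sketch: load, save, reload, and compare state dicts.
import tempfile

import torch
from transformers import AutoModelForCausalLM

def state_dicts_match(a: dict, b: dict) -> bool:
    return a.keys() == b.keys() and all(torch.equal(a[k], b[k]) for k in a)

model_id = "some-org/some-moe-checkpoint"  # placeholder: any model with a conversion mapping
model = AutoModelForCausalLM.from_pretrained(model_id)

with tempfile.TemporaryDirectory() as tmp:
    model.save_pretrained(tmp)  # keys can be written back in the original format (see save_original_format)
    reloaded = AutoModelForCausalLM.from_pretrained(tmp)  # the conversion mapping is applied again on load

print(state_dicts_match(model.state_dict(), reloaded.state_dict()))
```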