Skip to content

Commit e7f1262

Browse files
committed
Address review comments
Signed-off-by: Hollow Man <hollowman@opensuse.org>
1 parent fdb9a7b commit e7f1262

File tree

4 files changed: 36 additions (+36) and 6 deletions (-6).

tests/utils/test_vllm_weight_name_normalization_on_cpu.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@ class _FakeModel:
3131
def __init__(self):
3232
self.hf_to_vllm_mapper = _FakeMapper(
3333
{
34+
"model.language_model.layers.0.mlp.experts.base_layer.w13_weight": (
35+
"language_model.model.layers.0.mlp.experts.base_layer.w13_weight"
36+
),
37+
"model.language_model.layers.0.mlp.experts.base_layer.w2_weight": (
38+
"language_model.model.layers.0.mlp.experts.base_layer.w2_weight"
39+
),
3440
"model.language_model.layers.0.self_attn.qkv_proj.base_layer.weight": (
3541
"language_model.model.layers.0.self_attn.qkv_proj.base_layer.weight"
3642
),
@@ -89,6 +95,23 @@ def test_normalize_base_sync_weight_names_handles_bridge_inserted_base_layer_on_
8995
]
9096

9197

98+
def test_normalize_base_sync_weight_names_handles_fused_expert_leaf_params():
99+
worker = _make_worker(_FakeModel())
100+
tensor = torch.empty(0)
101+
102+
normalized_weights = worker._normalize_base_sync_weight_names(
103+
[
104+
("model.language_model.layers.0.mlp.experts.w13_weight", tensor),
105+
("model.language_model.layers.0.mlp.experts.base_layer.w2_weight", tensor),
106+
]
107+
)
108+
109+
assert [name for name, _ in normalized_weights] == [
110+
"model.language_model.layers.0.mlp.experts.base_layer.w13_weight",
111+
"model.language_model.layers.0.mlp.experts.base_layer.w2_weight",
112+
]
113+
114+
92115
def test_update_weights_from_ipc_accumulates_lora_tensors_across_buckets(monkeypatch):
93116
import verl.workers.rollout.vllm_rollout.bucketed_weight_transfer as bucketed_weight_transfer
94117

verl/utils/megatron_utils.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1342,10 +1342,16 @@ def check_mtp_config(model_config: HFModelConfig, engine_config: McoreEngineConf
13421342
Check and configure MTP (Multi-Token Prediction) settings.
13431343
13441344
Cases:
1345-
- mtp.enable == False and no MTP layers: return directly
1346-
- mtp.enable == False and has MTP layers: set num_nextn_predict_layers = 0
1347-
- mtp.enable == True and has MTP layers: configure override_transformer_config
1348-
- mtp.enable == True and no MTP layers: raise ValueError
1345+
- mtp.enable == False and neither ``num_nextn_predict_layers`` nor
1346+
``mtp_num_hidden_layers`` is enabled on ``hf_config`` /
1347+
``hf_config.text_config``: return directly.
1348+
- mtp.enable == False and MTP layers are configured: zero the first
1349+
supported MTP layer-count field (``num_nextn_predict_layers`` when
1350+
present, otherwise ``mtp_num_hidden_layers``).
1351+
- mtp.enable == True and MTP layers are configured: keep the existing
1352+
layer counts and populate ``override_transformer_config`` as needed.
1353+
- mtp.enable == True and no MTP layers are configured: raise
1354+
``ValueError``.
13491355
"""
13501356
text_hf_config = getattr(model_config.hf_config, "text_config", model_config.hf_config)
13511357
has_mtp = (

verl/workers/megatron_workers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def _init_hf_config_and_tf_config(
157157
assert (
158158
getattr(text_hf_config, "num_nextn_predict_layers", 0) > 0
159159
or getattr(text_hf_config, "mtp_num_hidden_layers", 0) > 0
160-
), "MTP requires at least one nextn_predict_layer"
160+
), "MTP requires at least one MTP layer (num_nextn_predict_layers or mtp_num_hidden_layers)"
161161
assert megatron_config.use_mbridge, "MTP requires use_mbridge to be True"
162162
override_transformer_config["mtp_loss_scaling_factor"] = self.config.model.mtp.mtp_loss_scaling_factor
163163
else:

verl/workers/rollout/vllm_rollout/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,8 @@ def _iter_model_weight_name_candidates(weight_name: str):
188188

189189
@staticmethod
190190
def _is_leaf_weight_or_bias_name(weight_name: str) -> bool:
191-
return weight_name.rsplit(".", 1)[-1] in {"weight", "bias"}
191+
leaf = weight_name.rsplit(".", 1)[-1]
192+
return leaf in {"weight", "bias"} or leaf.endswith(("_weight", "_bias"))
192193

193194
@classmethod
194195
def _strip_bridge_base_layer_from_expert_alias(cls, weight_name: str) -> str:

0 commit comments

Comments (0)