Skip to content

Commit 32c4e78

Browse files
committed
Address review comments
Signed-off-by: Hollow Man <hollowman@opensuse.org>
1 parent 8b29333 commit 32c4e78

File tree

4 files changed

+36
-6
lines changed

4 files changed

+36
-6
lines changed

tests/utils/test_vllm_weight_name_normalization_on_cpu.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@ class _FakeModel:
3131
def __init__(self):
3232
self.hf_to_vllm_mapper = _FakeMapper(
3333
{
34+
"model.language_model.layers.0.mlp.experts.base_layer.w13_weight": (
35+
"language_model.model.layers.0.mlp.experts.base_layer.w13_weight"
36+
),
37+
"model.language_model.layers.0.mlp.experts.base_layer.w2_weight": (
38+
"language_model.model.layers.0.mlp.experts.base_layer.w2_weight"
39+
),
3440
"model.language_model.layers.0.self_attn.qkv_proj.base_layer.weight": (
3541
"language_model.model.layers.0.self_attn.qkv_proj.base_layer.weight"
3642
),
@@ -89,6 +95,23 @@ def test_normalize_base_sync_weight_names_handles_bridge_inserted_base_layer_on_
8995
]
9096

9197

98+
def test_normalize_base_sync_weight_names_handles_fused_expert_leaf_params():
    """Fused-expert leaf params (``*_weight``) get the bridge ``base_layer`` alias.

    A name missing the ``base_layer`` segment is normalized to include it,
    while a name that already carries the segment passes through unchanged.
    """
    worker = _make_worker(_FakeModel())
    placeholder = torch.empty(0)

    source_weights = [
        ("model.language_model.layers.0.mlp.experts.w13_weight", placeholder),
        ("model.language_model.layers.0.mlp.experts.base_layer.w2_weight", placeholder),
    ]
    normalized_weights = worker._normalize_base_sync_weight_names(source_weights)

    expected_names = [
        "model.language_model.layers.0.mlp.experts.base_layer.w13_weight",
        "model.language_model.layers.0.mlp.experts.base_layer.w2_weight",
    ]
    assert [name for name, _ in normalized_weights] == expected_names
114+
92115
def test_update_weights_from_ipc_accumulates_lora_tensors_across_buckets(monkeypatch):
93116
import verl.workers.rollout.vllm_rollout.bucketed_weight_transfer as bucketed_weight_transfer
94117

verl/utils/megatron_utils.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1479,10 +1479,16 @@ def check_mtp_config(model_config: HFModelConfig, engine_config: McoreEngineConf
14791479
Check and configure MTP (Multi-Token Prediction) settings.
14801480
14811481
Cases:
1482-
- mtp.enable == False and no MTP layers: return directly
1483-
- mtp.enable == False and has MTP layers: set num_nextn_predict_layers = 0
1484-
- mtp.enable == True and has MTP layers: configure override_transformer_config
1485-
- mtp.enable == True and no MTP layers: raise ValueError
1482+
- mtp.enable == False and neither ``num_nextn_predict_layers`` nor
1483+
``mtp_num_hidden_layers`` is enabled on ``hf_config`` /
1484+
``hf_config.text_config``: return directly.
1485+
- mtp.enable == False and MTP layers are configured: zero the first
1486+
supported MTP layer-count field (``num_nextn_predict_layers`` when
1487+
present, otherwise ``mtp_num_hidden_layers``).
1488+
- mtp.enable == True and MTP layers are configured: keep the existing
1489+
layer counts and populate ``override_transformer_config`` as needed.
1490+
- mtp.enable == True and no MTP layers are configured: raise
1491+
``ValueError``.
14861492
"""
14871493
text_hf_config = getattr(model_config.hf_config, "text_config", model_config.hf_config)
14881494
has_mtp = (

verl/workers/megatron_workers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def _init_hf_config_and_tf_config(
159159
assert (
160160
getattr(text_hf_config, "num_nextn_predict_layers", 0) > 0
161161
or getattr(text_hf_config, "mtp_num_hidden_layers", 0) > 0
162-
), "MTP requires at least one nextn_predict_layer"
162+
), "MTP requires at least one MTP layer (num_nextn_predict_layers or mtp_num_hidden_layers)"
163163
assert megatron_config.use_mbridge, "MTP requires use_mbridge to be True"
164164
override_transformer_config["mtp_loss_scaling_factor"] = self.config.model.mtp.mtp_loss_scaling_factor
165165
else:

verl/workers/rollout/vllm_rollout/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,8 @@ def _iter_model_weight_name_candidates(weight_name: str):
195195

196196
@staticmethod
197197
def _is_leaf_weight_or_bias_name(weight_name: str) -> bool:
198-
return weight_name.rsplit(".", 1)[-1] in {"weight", "bias"}
198+
leaf = weight_name.rsplit(".", 1)[-1]
199+
return leaf in {"weight", "bias"} or leaf.endswith(("_weight", "_bias"))
199200

200201
@classmethod
201202
def _strip_bridge_base_layer_from_expert_alias(cls, weight_name: str) -> str:

0 commit comments

Comments
 (0)