
Commit 8b29333

[megatron] fix: Qwen3.5 LoRA support
- move `.base_layer` normalization out of the Megatron sender and into the vLLM receiver
- make vLLM weight-name normalization robust to packed modules and fused MoE logical aliases
- fix bucketed LoRA IPC updates so multi-bucket adapters are applied only after the final bucket arrives
- avoid incorrect engine offload behavior when param offload is disabled
- improve MTP compatibility for nested HF text configs and newer Megatron-LM APIs
- add targeted regression coverage for the new weight-sync and worker behaviors

The previous weight sync path relied on sender-side name rewriting, including hard-coded `.base_layer` handling. That made the Megatron/vLLM boundary brittle for newer models, especially packed projections and fused MoE modules.

In addition, the async LoRA update path assumed each adapter arrived in a single IPC bucket, which is not guaranteed by the bucketed transport. That could produce incomplete `add_lora` requests when LoRA tensors were split across multiple buckets.

There were also a few compatibility issues around:

- Qwen-style nested `text_config` MTP fields
- the newer Megatron-LM `process_mtp_loss` API
- actor engine offload behavior when automatic reload is not enabled

Instead of mutating exported Megatron parameter names on the sender side, the vLLM colocate worker now resolves incoming names against the live vLLM parameter namespace. This includes:

- generic add/remove resolution for `.base_layer` leaf params (`weight` / `bias`)
- stripping Bridge-inserted `.base_layer` from non-leaf fused-MoE logical aliases
- packed-owner lookup for aliases such as `q_proj -> qkv_proj` (see the sketch after this message)

This keeps the sender simpler and lets the receiver normalize names based on the actual loaded vLLM model structure.

`BucketedWeightReceiver` now forwards `is_last` to the callback. The async rollout weight update path uses that signal to:

- keep standard base-weight loading per bucket
- accumulate LoRA tensors across buckets
- call `add_lora` only once, after the final bucket is received

This aligns VERL's bucketed transport semantics with vLLM's expectation that one `add_lora` request contains a complete adapter tensor dict.

The old hard-coded stacked-parameter suffix list was removed. `megatron_peft_utils` now exposes generic helpers to:

- add `.base_layer`
- remove `.base_layer`
- resolve the correct name by probing the target namespace

The Megatron-to-HF module mapping was also expanded for GDN modules, including:

- `in_proj -> [in_proj_qkv, in_proj_z, in_proj_b, in_proj_a]`
- `out_proj -> [out_proj]`

MTP checks and disabling logic now work with nested HF text configs as well as configs that use `mtp_num_hidden_layers` instead of only `num_nextn_predict_layers`. The Megatron MTP patch also prefers the newer upstream `process_mtp_loss` helper when available, while preserving a fallback path for older Megatron-LM versions.

- `ActorRolloutRefWorker.update_weights` now only offloads the actor engine back to CPU when param offload is actually enabled.
- `vision_config` is now preserved in Megatron checkpoint manager backup state.

Signed-off-by: Hollow Man <hollowman@opensuse.org>
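As a rough illustration of the receiver-side normalization described above, the sketch below probes an incoming per-shard name against the live vLLM parameter namespace, first via a generic `.base_layer` toggle and then via the packed-modules mapping. It is a minimal sketch under stated assumptions, not the actual verl implementation: `named_params` (a set of live parameter names) and the probing order are illustrative.

```python
# Illustrative sketch only -- not the actual verl implementation. Assumes
# `named_params` is the set of live vLLM parameter names and `packed` is the
# model's packed_modules_mapping, e.g. {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}.
def normalize_name(name: str, named_params: set[str], packed: dict[str, list[str]]) -> str:
    if name in named_params:
        return name  # already a valid target name
    module, _, leaf = name.rpartition(".")
    if leaf in ("weight", "bias"):
        # Generic probe: does the LoRA-wrapped `.base_layer` variant exist?
        candidate = f"{module}.base_layer.{leaf}"
        if candidate in named_params:
            return candidate
        # Packed-owner lookup: q_proj has no parameter of its own in vLLM,
        # so check whether its fused owner (qkv_proj) carries `.base_layer`.
        parent, _, shard = module.rpartition(".")
        for owner, members in packed.items():
            if shard in members and f"{parent}.{owner}.base_layer.{leaf}" in named_params:
                return candidate  # the shard name gains `.base_layer` too
    return name
```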
1 parent 809f2d8 commit 8b29333

12 files changed: 531 additions & 74 deletions

tests/utils/test_bucketed_weight_transfer.py

Lines changed: 6 additions & 1 deletion

```diff
@@ -86,7 +86,12 @@ def _receiver_fn(zmq_handle, use_shm, result_queue):
         use_shm=use_shm,
     )
     received = []
-    receiver.receive_weights(on_bucket_received=lambda w: received.extend([(name, t.clone()) for name, t in w]))
+
+    def on_bucket_received(weights, *, is_last):
+        del is_last
+        received.extend([(name, t.clone()) for name, t in weights])
+
+    receiver.receive_weights(on_bucket_received=on_bucket_received)
     # Only send lightweight metadata + checksum back through the queue
     summaries = [(name, t.dtype, tuple(t.shape), t.float().sum().item()) for name, t in received]
     result_queue.put(summaries)
```
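The updated callback signature above implies the receiver itself tells the callback when the final bucket arrives. A minimal sketch of that forwarding, where the metadata handshake and internal method names are assumptions rather than verl's actual API:

```python
# Illustrative sketch of is_last forwarding in a bucketed receiver.
# `_recv_metadata` and `_recv_bucket` are hypothetical transport calls.
def receive_weights(self, on_bucket_received):
    meta = self._recv_metadata()           # hypothetical: includes total bucket count
    num_buckets = meta["num_buckets"]
    for i in range(num_buckets):
        weights = self._recv_bucket(i)     # hypothetical: list of (name, tensor) pairs
        on_bucket_received(weights, is_last=(i == num_buckets - 1))
```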
Lines changed: 54 additions & 0 deletions

```python
# Copyright 2026 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from verl.utils.megatron_peft_utils import convert_megatron_to_hf_target_modules, resolve_base_layer_name


def test_convert_megatron_to_hf_target_modules_expands_gdn_in_proj():
    converted = convert_megatron_to_hf_target_modules(["in_proj", "out_proj"])

    assert converted == [
        "in_proj_qkv",
        "in_proj_z",
        "in_proj_b",
        "in_proj_a",
        "out_proj",
    ]


def test_resolve_base_layer_name_adds_suffix_when_target_requires_it():
    resolved_name = resolve_base_layer_name(
        "model.layers.0.self_attn.q_proj.weight",
        exists=lambda candidate: candidate == "model.layers.0.self_attn.q_proj.base_layer.weight",
    )

    assert resolved_name == "model.layers.0.self_attn.q_proj.base_layer.weight"


def test_resolve_base_layer_name_removes_suffix_when_target_does_not_use_it():
    resolved_name = resolve_base_layer_name(
        "model.visual.merger.linear_fc1.base_layer.weight",
        exists=lambda candidate: candidate == "model.visual.merger.linear_fc1.weight",
    )

    assert resolved_name == "model.visual.merger.linear_fc1.weight"


def test_resolve_base_layer_name_keeps_existing_name():
    resolved_name = resolve_base_layer_name(
        "model.visual.merger.linear_fc1.weight",
        exists=lambda candidate: candidate == "model.visual.merger.linear_fc1.weight",
    )

    assert resolved_name == "model.visual.merger.linear_fc1.weight"
```
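Read together, these tests pin down the resolver's contract: keep a name that already exists in the target namespace, otherwise toggle `.base_layer` on a `weight`/`bias` leaf and probe again. A minimal sketch consistent with the tests follows; the real helper lives in `verl.utils.megatron_peft_utils`, and this reconstruction is illustrative only:

```python
# Illustrative reconstruction of the resolver's behavior as pinned by the
# tests above. `exists` is a predicate over the target parameter namespace.
def resolve_base_layer_name(name, exists):
    if exists(name):
        return name  # already resolves in the target namespace
    prefix, _, leaf = name.rpartition(".")
    if leaf in ("weight", "bias"):
        if prefix.endswith(".base_layer"):
            # Remove `.base_layer` when the target does not use it.
            candidate = prefix[: -len(".base_layer")] + "." + leaf
        else:
            # Add `.base_layer` when the target requires it.
            candidate = f"{prefix}.base_layer.{leaf}"
        if exists(candidate):
            return candidate
    return name
```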
Lines changed: 135 additions & 0 deletions

```python
# Copyright 2026 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from types import SimpleNamespace

import torch

from verl.workers.rollout.vllm_rollout.utils import VLLM_LORA_INT_ID, vLLMColocateWorkerExtension


class _FakeMapper:
    def __init__(self, mapping: dict[str, str]):
        self.mapping = mapping

    def apply_list(self, names: list[str]) -> list[str]:
        return [self.mapping.get(name, name) for name in names]


class _FakeModel:
    def __init__(self):
        self.hf_to_vllm_mapper = _FakeMapper(
            {
                "model.language_model.layers.0.self_attn.qkv_proj.base_layer.weight": (
                    "language_model.model.layers.0.self_attn.qkv_proj.base_layer.weight"
                ),
            }
        )
        self.packed_modules_mapping = {
            "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        }

    def named_parameters(self, remove_duplicate: bool = False):
        del remove_duplicate
        yield "language_model.model.layers.0.mlp.experts.base_layer.w13_weight", torch.empty(0)
        yield "language_model.model.layers.0.mlp.experts.base_layer.w2_weight", torch.empty(0)
        yield "language_model.model.layers.0.self_attn.qkv_proj.base_layer.weight", torch.empty(0)


def _make_worker(model):
    worker = object.__new__(vLLMColocateWorkerExtension)
    worker.model_runner = SimpleNamespace(model=model)
    return worker


def test_normalize_base_sync_weight_names_preserves_expert_logical_aliases():
    worker = _make_worker(_FakeModel())
    tensor = torch.empty(0)

    normalized_weights = worker._normalize_base_sync_weight_names(
        [
            ("model.language_model.layers.0.mlp.experts.gate_up_proj", tensor),
            ("model.language_model.layers.0.mlp.experts.down_proj", tensor),
            ("model.language_model.layers.0.self_attn.q_proj.weight", tensor),
        ]
    )

    assert [name for name, _ in normalized_weights] == [
        "model.language_model.layers.0.mlp.experts.gate_up_proj",
        "model.language_model.layers.0.mlp.experts.down_proj",
        "model.language_model.layers.0.self_attn.q_proj.base_layer.weight",
    ]


def test_normalize_base_sync_weight_names_handles_bridge_inserted_base_layer_on_fused_experts():
    worker = _make_worker(_FakeModel())
    tensor = torch.empty(0)

    normalized_weights = worker._normalize_base_sync_weight_names(
        [
            ("model.language_model.layers.0.mlp.experts.base_layer.gate_up_proj", tensor),
            ("model.language_model.layers.0.mlp.experts.base_layer.down_proj", tensor),
        ]
    )

    assert [name for name, _ in normalized_weights] == [
        "model.language_model.layers.0.mlp.experts.gate_up_proj",
        "model.language_model.layers.0.mlp.experts.down_proj",
    ]


def test_update_weights_from_ipc_accumulates_lora_tensors_across_buckets(monkeypatch):
    import verl.workers.rollout.vllm_rollout.bucketed_weight_transfer as bucketed_weight_transfer

    class _FakeBucketReceiver:
        def __init__(self, zmq_handle, device, use_shm):
            del zmq_handle, device, use_shm

        def receive_weights(self, on_bucket_received):
            on_bucket_received(
                [("layers.0.self_attn.q_proj.lora_A.weight", torch.ones(1))],
                is_last=False,
            )
            on_bucket_received(
                [("layers.0.self_attn.q_proj.lora_B.weight", torch.zeros(1))],
                is_last=True,
            )

    monkeypatch.setattr(bucketed_weight_transfer, "BucketedWeightReceiver", _FakeBucketReceiver)

    worker = _make_worker(_FakeModel())
    worker.model_runner.vllm_config = SimpleNamespace()
    worker.device = torch.device("cpu")
    worker.local_rank = 0
    worker._is_qat_model = False
    worker._get_zmq_handle = lambda: "ipc:///tmp/test-bucketed-lora.sock"

    removed_loras = []
    added_requests = []
    worker.remove_lora = removed_loras.append

    def _add_lora(lora_request):
        added_requests.append(lora_request)
        return True

    worker.add_lora = _add_lora

    worker.update_weights_from_ipc(peft_config={"r": 1}, base_sync_done=True)

    assert removed_loras == [VLLM_LORA_INT_ID]
    assert len(added_requests) == 1
    assert set(added_requests[0].lora_tensors) == {
        "layers.0.self_attn.q_proj.lora_A.weight",
        "layers.0.self_attn.q_proj.lora_B.weight",
    }
```
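The last test encodes the intended bucket semantics: LoRA tensors accumulate across buckets, and a single `remove_lora`/`add_lora` pair fires only when `is_last` is set. A minimal sketch of such a callback on the worker follows; `_pending_lora`, `_load_base_weights`, and `_make_lora_request` are hypothetical placeholders, not verl's actual names:

```python
# Illustrative sketch of the is_last-aware update callback. The helper
# names (_pending_lora, _load_base_weights, _make_lora_request) are
# hypothetical placeholders used only for this sketch.
def _on_bucket_received(self, weights, *, is_last):
    base_weights = []
    for name, tensor in weights:
        if ".lora_A." in name or ".lora_B." in name:
            self._pending_lora[name] = tensor   # accumulate across buckets
        else:
            base_weights.append((name, tensor))
    if base_weights:
        self._load_base_weights(base_weights)   # standard per-bucket base loading
    if is_last and self._pending_lora:
        self.remove_lora(VLLM_LORA_INT_ID)      # drop any stale adapter first
        self.add_lora(self._make_lora_request(self._pending_lora))  # one complete dict
        self._pending_lora = {}
```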
Lines changed: 97 additions & 0 deletions

```python
# Copyright 2026 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from types import SimpleNamespace

import torch
from omegaconf import OmegaConf

from verl.workers.engine_workers import ActorRolloutRefWorker


class _DummyEngine:
    def __init__(self, *, is_param_offload_enabled: bool):
        self.is_param_offload_enabled = is_param_offload_enabled
        self.get_per_tensor_param_calls = []
        self.to_calls = []

    def get_per_tensor_param(self, **kwargs):
        self.get_per_tensor_param_calls.append(kwargs)

        def _weights():
            yield ("model.embed_tokens.weight", torch.tensor([1.0]))

        return _weights(), None

    def to(self, device: str, model: bool = True, optimizer: bool = True, grad: bool = True):
        self.to_calls.append((device, model, optimizer, grad))


class _DummyRollout:
    def __init__(self):
        self.sleep_level = 0
        self.update_calls = []
        self.resume_calls = []

    async def resume(self, tags):
        self.resume_calls.append(tags)

    async def update_weights(self, weights, **kwargs):
        self.update_calls.append({"weights": list(weights), **kwargs})


def _build_worker(*, is_param_offload_enabled: bool):
    worker = object.__new__(ActorRolloutRefWorker)
    worker.config = OmegaConf.create(
        {
            "rollout": {
                "checkpoint_engine": {"backend": "naive"},
                "free_cache_engine": False,
            }
        }
    )
    worker.actor = SimpleNamespace(engine=_DummyEngine(is_param_offload_enabled=is_param_offload_enabled))
    worker.rollout = _DummyRollout()
    worker.base_sync_done = True
    worker.layered_summon = False
    worker.peft_merge = False
    return worker


def test_update_weights_does_not_offload_actor_when_param_offload_disabled(monkeypatch):
    monkeypatch.setattr("verl.workers.engine_workers.set_expandable_segments", lambda *_: None)
    monkeypatch.setattr("verl.workers.engine_workers.log_gpu_memory_usage", lambda *args, **kwargs: None)
    monkeypatch.setattr("verl.workers.engine_workers.aggressive_empty_cache", lambda *args, **kwargs: None)

    worker = _build_worker(is_param_offload_enabled=False)

    asyncio.run(ActorRolloutRefWorker.update_weights(worker))

    assert worker.actor.engine.to_calls == []
    assert worker.actor.engine.get_per_tensor_param_calls == [{"layered_summon": False, "base_sync_done": True}]
    assert len(worker.rollout.update_calls) == 1
    assert worker.rollout.update_calls[0]["base_sync_done"] is True


def test_update_weights_offloads_actor_when_param_offload_enabled(monkeypatch):
    monkeypatch.setattr("verl.workers.engine_workers.set_expandable_segments", lambda *_: None)
    monkeypatch.setattr("verl.workers.engine_workers.log_gpu_memory_usage", lambda *args, **kwargs: None)
    monkeypatch.setattr("verl.workers.engine_workers.aggressive_empty_cache", lambda *args, **kwargs: None)

    worker = _build_worker(is_param_offload_enabled=True)

    asyncio.run(ActorRolloutRefWorker.update_weights(worker))

    assert worker.actor.engine.to_calls == [("cpu", True, False, False)]
```
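These two tests pin the offload fix down to a single condition. A minimal sketch of the relevant branch inside `update_weights`, with the surrounding method body elided and attribute names mirroring the test doubles above:

```python
# Illustrative sketch of the offload guard exercised by the tests above.
if self.actor.engine.is_param_offload_enabled:
    # Only return the actor engine to CPU when param offload is enabled;
    # otherwise leave weights on GPU, since no automatic reload would
    # bring them back for the next training step.
    self.actor.engine.to("cpu", model=True, optimizer=False, grad=False)
```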

verl/utils/checkpoint/megatron_checkpoint_manager.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -645,6 +645,7 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
             "grad_sync_func",
             "param_sync_func",
             "generation_config",
+            "vision_config",
             "_pg_collection",
         ]
         backup = {}
```
