
Commit d7c88e5

[veomni] feat: support offloading/loading the veomni model/optimizer (verl-project#4916)
### What does this PR do?

This PR adds support for offloading both the model and the optimizer (in VeOmni style) to CPU, as well as onloading them back to the device. Additionally, it includes two model conversion scripts required by VeOmni:

- `moe_merge.py`: converts models from Hugging Face (HF) format into a format compatible with VeOmni.
- `moe_split.py`: converts checkpoints generated by VeOmni training back into HF format.

### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
- [ ] If your PR is related to the `recipe` submodule, please also update the reference to the submodule commit via `git submodule update --remote` or `cd recipe && git pull origin main`.
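A hedged usage sketch of the two conversion scripts described above. Only the script names and CLI flags come from this PR; the checkpoint paths are placeholders and the training step in between is elided.

```python
# Round trip: HF checkpoint -> VeOmni stacked-expert layout -> train -> back to HF format.
import subprocess

# 1. Merge per-expert HF weights into stacked tensors for VeOmni.
subprocess.run(
    [
        "python", "scripts/veomni/moe_merge.py",
        "--raw_hf_path", "/ckpts/qwen3-moe-hf",            # placeholder path
        "--merge_hf_path", "/ckpts/qwen3-moe-veomni",      # placeholder path
    ],
    check=True,
)

# 2. ... train with the verl VeOmni engine on the merged checkpoint ...

# 3. Split the stacked tensors of the trained checkpoint back into per-expert HF weights.
subprocess.run(
    [
        "python", "scripts/veomni/moe_split.py",
        "--merge_hf_path", "/ckpts/qwen3-moe-veomni-trained",  # placeholder path
        "--split_hf_path", "/ckpts/qwen3-moe-hf-final",        # placeholder path
    ],
    check=True,
)
```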
1 parent 750d5b5 commit d7c88e5

5 files changed

Lines changed: 336 additions & 1 deletion


scripts/veomni/moe_merge.py

Lines changed: 121 additions & 0 deletions
```python
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Merge individual MoE expert weights into stacked tensors for efficient loading.

This script takes a HuggingFace checkpoint with individual expert weights
(e.g., model.layers.{i}.mlp.experts.{j}.gate_proj.weight) and merges them
into stacked tensors (e.g., model.layers.{i}.mlp.experts.gate_proj) for
faster loading and better memory efficiency in VeOmni.

The merging process:
1. Loads individual expert weights from the HF checkpoint
2. Stacks them into single tensors for each projection type
3. Handles all three projection types: gate_proj, up_proj, down_proj
4. Supports both Qwen3-MoE (num_experts) and DeepSeek (n_routed_experts) formats
5. Handles models with initial dense layers (first_k_dense_replace)

Usage: python moe_merge.py --raw_hf_path <input_checkpoint> --merge_hf_path <output_dir>
"""

import os
from argparse import ArgumentParser
from dataclasses import dataclass
from glob import glob
from typing import Generator

import torch
from safetensors.torch import safe_open
from tqdm import tqdm
from transformers import AutoConfig
from veomni.models import build_tokenizer, save_model_weights


@dataclass
class StateDictIterator:
    filepath: str

    def __iter__(self) -> Generator[tuple[str, "torch.Tensor"], None, None]:
        if self.filepath.endswith(".safetensors"):
            with safe_open(self.filepath, framework="pt", device="cpu") as f:
                for key in f.keys():
                    yield key, f.get_tensor(key)

        else:
            state_dict = torch.load(self.filepath, map_location="cpu", weights_only=True, mmap=True)
            for key in state_dict.keys():
                yield key, state_dict[key]


def main(raw_hf_path, merge_hf_path):
    torch.set_default_dtype(torch.bfloat16)
    os.makedirs(merge_hf_path, exist_ok=True)

    config = AutoConfig.from_pretrained(raw_hf_path)
    tokenizer = build_tokenizer(raw_hf_path)

    safetensor_files = list(glob(os.path.join(raw_hf_path, "*.safetensors")))
    safetensor_files.sort()
    state_dict_iterators = [StateDictIterator(shard_file) for shard_file in safetensor_files]
    new_state_dict = {}
    for state_dict_iterator in tqdm(state_dict_iterators, desc="Loading checkpoint shards"):
        for name, tensor in state_dict_iterator:
            new_state_dict[name] = tensor.cpu()

    print(new_state_dict.keys())

    if hasattr(config, "num_experts"):
        # qwen3moe
        num_experts = config.num_experts
    elif hasattr(config, "n_routed_experts"):
        # deepseek
        num_experts = config.n_routed_experts
    else:
        raise RuntimeError("could not find how many experts to assign")
    num_hidden_layers = config.num_hidden_layers

    if hasattr(config, "first_k_dense_replace"):
        # deepseek first k dense layer
        moe_layer_start_idx = config.first_k_dense_replace
    else:
        # moe layer only in the model
        moe_layer_start_idx = 0

    for i in range(moe_layer_start_idx, num_hidden_layers):
        gate_proj = []
        for j in range(num_experts):
            gate_proj.append(new_state_dict.pop(f"model.layers.{i}.mlp.experts.{j}.gate_proj.weight"))

        new_state_dict[f"model.layers.{i}.mlp.experts.gate_proj"] = torch.stack(gate_proj)
        up_proj = []
        for j in range(num_experts):
            up_proj.append(new_state_dict.pop(f"model.layers.{i}.mlp.experts.{j}.up_proj.weight"))

        new_state_dict[f"model.layers.{i}.mlp.experts.up_proj"] = torch.stack(up_proj)
        down_proj = []
        for j in range(num_experts):
            down_proj.append(new_state_dict.pop(f"model.layers.{i}.mlp.experts.{j}.down_proj.weight"))

        new_state_dict[f"model.layers.{i}.mlp.experts.down_proj"] = torch.stack(down_proj)

    model_assets = [config, tokenizer]
    save_model_weights(merge_hf_path, new_state_dict, model_assets=model_assets)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--raw_hf_path", type=str, required=True)
    parser.add_argument("--merge_hf_path", type=str, required=True)
    args = parser.parse_args()
    main(args.raw_hf_path, args.merge_hf_path)
```
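To make the key renaming concrete, here is a toy illustration (not part of the PR) of what the merge does to a layer's expert weights; the two-expert count and tensor shapes are made up:

```python
import torch

# A miniature HF-style state dict: 2 experts in layer 0, each with its own weight tensor.
num_experts = 2
state_dict = {
    f"model.layers.0.mlp.experts.{j}.{proj}.weight": torch.randn(4, 8)
    for j in range(num_experts)
    for proj in ("gate_proj", "up_proj", "down_proj")
}

# The merge pops the per-expert weights and stacks them into one tensor per projection,
# mirroring the loop in moe_merge.py above.
for proj in ("gate_proj", "up_proj", "down_proj"):
    stacked = torch.stack(
        [state_dict.pop(f"model.layers.0.mlp.experts.{j}.{proj}.weight") for j in range(num_experts)]
    )
    state_dict[f"model.layers.0.mlp.experts.{proj}"] = stacked

print(state_dict["model.layers.0.mlp.experts.gate_proj"].shape)  # torch.Size([2, 4, 8])
```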

scripts/veomni/moe_split.py

Lines changed: 96 additions & 0 deletions
```python
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Reverse process of moe_merge.py - splits merged MoE expert weights back to individual experts.

This script takes a HF checkpoint that has been processed by moe_merge.py (where expert weights
are stacked into single tensors) and splits them back to the original format with individual
expert weights.

The process reverses the merging by:
1. Loading stacked tensors like model.layers.{i}.mlp.experts.gate_proj
2. Unstacking them back to individual experts model.layers.{i}.mlp.experts.{j}.gate_proj.weight
3. Handling all three projection types: gate_proj, up_proj, down_proj

Usage: python moe_split.py --merge_hf_path <merged_checkpoint> --split_hf_path <output_dir>
"""

import os
from argparse import ArgumentParser
from dataclasses import dataclass
from glob import glob
from typing import Generator

import torch
from safetensors.torch import safe_open
from tqdm import tqdm
from transformers import AutoConfig
from veomni.models import build_tokenizer, save_model_weights


@dataclass
class StateDictIterator:
    filepath: str

    def __iter__(self) -> Generator[tuple[str, "torch.Tensor"], None, None]:
        if self.filepath.endswith(".safetensors"):
            with safe_open(self.filepath, framework="pt", device="cpu") as f:
                for key in f.keys():
                    yield key, f.get_tensor(key)

        else:
            state_dict = torch.load(self.filepath, map_location="cpu", weights_only=True, mmap=True)
            for key in state_dict.keys():
                yield key, state_dict[key]


def main(merge_hf_path, split_hf_path):
    torch.set_default_dtype(torch.bfloat16)
    os.makedirs(split_hf_path, exist_ok=True)

    config = AutoConfig.from_pretrained(merge_hf_path)
    tokenizer = build_tokenizer(merge_hf_path)

    safetensor_files = list(glob(os.path.join(merge_hf_path, "*.safetensors")))
    safetensor_files.sort()
    state_dict_iterators = [StateDictIterator(shard_file) for shard_file in safetensor_files]
    new_state_dict = {}
    for state_dict_iterator in tqdm(state_dict_iterators, desc="Loading checkpoint shards"):
        for name, tensor in state_dict_iterator:
            new_state_dict[name] = tensor.cpu()

    num_experts = config.num_experts
    num_hidden_layers = config.num_hidden_layers
    for i in range(num_hidden_layers):
        print(f"Converting layer {i}")
        for proj_name in ["gate_proj", "up_proj", "down_proj"]:
            stacked_key = f"model.layers.{i}.mlp.experts.{proj_name}"
            if stacked_key in new_state_dict:
                stacked_tensor = new_state_dict.pop(stacked_key)
                for j in range(num_experts):
                    expert_key = f"model.layers.{i}.mlp.experts.{j}.{proj_name}.weight"
                    new_state_dict[expert_key] = stacked_tensor[j]

    model_assets = [config, tokenizer]

    print("Saving to safetensors")
    save_model_weights(split_hf_path, new_state_dict, model_assets=model_assets)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--merge_hf_path", type=str, required=True)
    parser.add_argument("--split_hf_path", type=str, required=True)
    args = parser.parse_args()
    main(args.merge_hf_path, args.split_hf_path)
```
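After a merge → split round trip, the split output should match the original HF checkpoint. A hedged sanity check one could run (not part of the PR; the paths and helper below are placeholders):

```python
import os
from glob import glob

import torch
from safetensors.torch import safe_open


def load_all_shards(path):
    """Load every *.safetensors shard under `path` into a single CPU state dict."""
    state = {}
    for shard in sorted(glob(os.path.join(path, "*.safetensors"))):
        with safe_open(shard, framework="pt", device="cpu") as f:
            for key in f.keys():
                state[key] = f.get_tensor(key)
    return state


original = load_all_shards("/ckpts/qwen3-moe-hf")         # placeholder: original HF checkpoint
roundtrip = load_all_shards("/ckpts/qwen3-moe-hf-final")  # placeholder: output of moe_split.py

assert original.keys() == roundtrip.keys(), "key sets differ after the round trip"
for key in original:
    torch.testing.assert_close(original[key], roundtrip[key])
print("round trip OK")
```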

verl/trainer/config/engine/veomni.yaml

Lines changed: 6 additions & 0 deletions
```diff
@@ -1,6 +1,12 @@
 # Target class for this configuration
 _target_: verl.workers.config.VeOmniEngineConfig
 
+# Whether to offload model parameters to CPU
+param_offload: False
+
+# Whether to offload optimizer state to CPU
+optimizer_offload: False
+
 # fsdp or fsdp2
 data_parallel_mode: fsdp2
 
```
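A minimal sketch of the new flags in code form. In a real run they are picked up from this YAML file; constructing the dataclass directly, as below, is purely illustrative and assumes the remaining `VeOmniEngineConfig` fields have defaults:

```python
from verl.workers.config import VeOmniEngineConfig

# Hypothetical direct construction, for illustration only — the normal path is the YAML above.
engine_config = VeOmniEngineConfig(
    param_offload=True,       # offload model parameters to CPU when the engine is idle
    optimizer_offload=True,   # offload optimizer state to CPU when the engine is idle
    data_parallel_mode="fsdp2",
)
```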

verl/utils/veomni_utils.py

Lines changed: 78 additions & 0 deletions
```python
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from verl.utils.device import get_device_id, get_torch_device


@torch.no_grad()
def offload_veomni_model_to_cpu(model, empty_cache: bool = True):
    from veomni.distributed.parallel_state import get_parallel_state

    assert get_parallel_state().dp_mode == "fsdp2", "Only support fsdp2 offloading for VeOmni model"

    model.cpu()
    if empty_cache:
        get_torch_device().empty_cache()


@torch.no_grad()
def load_veomni_model_to_gpu(model):
    from veomni.distributed.parallel_state import get_parallel_state

    assert get_parallel_state().dp_mode == "fsdp2", "Only support fsdp2 offloading for VeOmni model"

    device = get_device_id()
    model.to(device)


@torch.no_grad()
def offload_veomni_optimizer(optimizer):
    optimizers = []
    # Check if this is a MultiOptimizer (for ep and non-ep parameters when ep+fsdp2 is enabled)
    if hasattr(optimizer, "_is_multi_optimizer") and optimizer._is_multi_optimizer:
        optimizers.extend(optimizer.optimizers_dict.values())
    else:
        optimizers.append(optimizer)

    for opt in optimizers:
        if not opt.state:
            continue
        for param_group in opt.param_groups:
            for param in param_group["params"]:
                state = opt.state[param]
                for key, value in state.items():
                    if isinstance(value, torch.Tensor):
                        state[key] = value.to("cpu", non_blocking=True)


@torch.no_grad()
def load_veomni_optimizer(optimizer, device_id):
    optimizers = []
    # Check if this is a MultiOptimizer (for ep and non-ep parameters when ep+fsdp2 is enabled)
    if hasattr(optimizer, "_is_multi_optimizer") and optimizer._is_multi_optimizer:
        optimizers.extend(optimizer.optimizers_dict.values())
    else:
        optimizers.append(optimizer)

    for opt in optimizers:
        if not opt.state:
            continue
        for param_group in opt.param_groups:
            for param in param_group["params"]:
                state = opt.state[param]
                for key, value in state.items():
                    if isinstance(value, torch.Tensor):
                        state[key] = value.to(device_id, non_blocking=True)
```
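A hedged usage sketch of these helpers. `model` is assumed to be an fsdp2-wrapped VeOmni module and `optimizer` the optimizer built by the VeOmni engine (possibly a MultiOptimizer); both are stand-ins here, not objects defined in this snippet:

```python
from verl.utils.device import get_device_id
from verl.utils.veomni_utils import (
    load_veomni_model_to_gpu,
    load_veomni_optimizer,
    offload_veomni_model_to_cpu,
    offload_veomni_optimizer,
)

# Free accelerator memory between training phases (e.g., while rollout is running).
offload_veomni_model_to_cpu(model, empty_cache=True)
offload_veomni_optimizer(optimizer)

# Bring parameters and optimizer state back before the next training step.
load_veomni_model_to_gpu(model)
load_veomni_optimizer(optimizer, get_device_id())
```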

verl/workers/engine/veomni/transformer_impl.py

Lines changed: 35 additions & 1 deletion
```diff
@@ -31,9 +31,15 @@
 from verl.trainer.config import CheckpointConfig
 from verl.utils import tensordict_utils as tu
 from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
-from verl.utils.device import get_device_id
+from verl.utils.device import get_device_id, get_device_name
 from verl.utils.fsdp_utils import fsdp_version
 from verl.utils.profiler import log_gpu_memory_usage
+from verl.utils.veomni_utils import (
+    load_veomni_model_to_gpu,
+    load_veomni_optimizer,
+    offload_veomni_model_to_cpu,
+    offload_veomni_optimizer,
+)
 from verl.workers.config import HFModelConfig, VeOmniEngineConfig, VeOmniOptimizerConfig
 from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
 
@@ -217,6 +223,34 @@ def _build_model_optimizer(self):
             self.engine_config.activation_gpu_limit,
         )
 
+    def to(self, device: str, model: bool = True, optimizer: bool = True, grad: bool = True):
+        """
+        Move model parameters, optimizer states, or both to the specified device.
+        Note that this function executes irrespective of offload config. It serves as manual control.
+
+        Args:
+            device: Target device identifier.
+            model: If True, move the model.
+            optimizer: If True, move the optimizer states.
+        """
+        super(FSDPEngine, self).to(device=device, model=model, optimizer=optimizer, grad=grad)
+
+        device_name = get_device_name()
+
+        assert device in (device_name, "cpu")
+        if device == device_name:
+            if model:
+                load_veomni_model_to_gpu(self.module)
+            if optimizer and self.optimizer is not None:
+                load_veomni_optimizer(self.optimizer, device)
+        elif device == "cpu":
+            if model:
+                offload_veomni_model_to_cpu(self.module)
+            if optimizer and self.optimizer is not None:
+                offload_veomni_optimizer(self.optimizer)
+        else:
+            raise ValueError(f"Invalid device type: {device}")
+
     def optimizer_step(self):
         """
         Perform an optimization step using the optimizer.
```
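A hedged sketch of the manual-control path this `to()` method adds. `engine` is assumed to be an already-built VeOmni engine instance (not constructed here), and the move happens regardless of the `param_offload`/`optimizer_offload` settings:

```python
from verl.utils.device import get_device_name

engine.to("cpu")               # offload module parameters and optimizer state to host memory
# ... run a memory-hungry phase such as rollout ...
engine.to(get_device_name())   # onload them back onto the accelerator
```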
