121 changes: 121 additions & 0 deletions scripts/veomni/moe_merge.py
@@ -0,0 +1,121 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Merge individual MoE expert weights into stacked tensors for efficient loading.

This script takes a HuggingFace checkpoint with individual expert weights
(e.g., model.layers.{i}.mlp.experts.{j}.gate_proj.weight) and merges them
into stacked tensors (e.g., model.layers.{i}.mlp.experts.gate_proj) for
faster loading and better memory efficiency in VeOmni.

The merging process:
1. Loads individual expert weights from the HF checkpoint
2. Stacks them into single tensors for each projection type
3. Handles all three projection types: gate_proj, up_proj, down_proj
4. Supports both Qwen3-MoE (num_experts) and DeepSeek (n_routed_experts) formats
5. Handles models with initial dense layers (first_k_dense_replace)

Usage: python moe_merge.py --raw_hf_path <input_checkpoint> --merge_hf_path <output_dir>
"""

import os
from argparse import ArgumentParser
from dataclasses import dataclass
from glob import glob
from typing import Generator

import torch
from safetensors.torch import safe_open
from tqdm import tqdm
from transformers import AutoConfig
from veomni.models import build_tokenizer, save_model_weights


@dataclass
class StateDictIterator:
    filepath: str

    def __iter__(self) -> Generator[tuple[str, "torch.Tensor"], None, None]:
        if self.filepath.endswith(".safetensors"):
            with safe_open(self.filepath, framework="pt", device="cpu") as f:
                for key in f.keys():
                    yield key, f.get_tensor(key)

        else:
            state_dict = torch.load(self.filepath, map_location="cpu", weights_only=True, mmap=True)
            for key in state_dict.keys():
                yield key, state_dict[key]


def main(raw_hf_path, merge_hf_path):
    torch.set_default_dtype(torch.bfloat16)
    os.makedirs(merge_hf_path, exist_ok=True)

    config = AutoConfig.from_pretrained(raw_hf_path)
    tokenizer = build_tokenizer(raw_hf_path)

    safetensor_files = list(glob(os.path.join(raw_hf_path, "*.safetensors")))
    safetensor_files.sort()
    state_dict_iterators = [StateDictIterator(shard_file) for shard_file in safetensor_files]
    new_state_dict = {}
    for state_dict_iterator in tqdm(state_dict_iterators, desc="Loading checkpoint shards"):
        for name, tensor in state_dict_iterator:
            new_state_dict[name] = tensor.cpu()

    print(new_state_dict.keys())

    if hasattr(config, "num_experts"):
        # qwen3moe
        num_experts = config.num_experts
    elif hasattr(config, "n_routed_experts"):
        # deepseek
        num_experts = config.n_routed_experts
    else:
        raise RuntimeError("could not find how many experts to assign")
    num_hidden_layers = config.num_hidden_layers

    if hasattr(config, "first_k_dense_replace"):
        # deepseek: the first k layers are dense, MoE layers start afterwards
        moe_layer_start_idx = config.first_k_dense_replace
    else:
        # every layer in the model is an MoE layer
        moe_layer_start_idx = 0

    for i in range(moe_layer_start_idx, num_hidden_layers):
        gate_proj = []
        for j in range(num_experts):
            gate_proj.append(new_state_dict.pop(f"model.layers.{i}.mlp.experts.{j}.gate_proj.weight"))

        new_state_dict[f"model.layers.{i}.mlp.experts.gate_proj"] = torch.stack(gate_proj)
        up_proj = []
        for j in range(num_experts):
            up_proj.append(new_state_dict.pop(f"model.layers.{i}.mlp.experts.{j}.up_proj.weight"))

        new_state_dict[f"model.layers.{i}.mlp.experts.up_proj"] = torch.stack(up_proj)
        down_proj = []
        for j in range(num_experts):
            down_proj.append(new_state_dict.pop(f"model.layers.{i}.mlp.experts.{j}.down_proj.weight"))

        new_state_dict[f"model.layers.{i}.mlp.experts.down_proj"] = torch.stack(down_proj)

    model_assets = [config, tokenizer]
    save_model_weights(merge_hf_path, new_state_dict, model_assets=model_assets)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--raw_hf_path", type=str, required=True)
    parser.add_argument("--merge_hf_path", type=str, required=True)
    args = parser.parse_args()
    main(args.raw_hf_path, args.merge_hf_path)
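
For reference, a minimal sketch of inspecting the merged layout after running the script (the ./merged_ckpt path is an assumption, not part of this PR); each stacked key should hold one 3-D tensor of shape (num_experts, out_features, in_features):

# Hypothetical sanity check on a merged checkpoint; the directory path is an assumption.
import glob

from safetensors.torch import safe_open

merged_dir = "./merged_ckpt"  # the directory passed as --merge_hf_path
for shard in sorted(glob.glob(f"{merged_dir}/*.safetensors")):
    with safe_open(shard, framework="pt", device="cpu") as f:
        for key in f.keys():
            if key.endswith("mlp.experts.gate_proj"):
                # e.g. model.layers.3.mlp.experts.gate_proj -> torch.Size([num_experts, intermediate_size, hidden_size])
                print(key, f.get_tensor(key).shape)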
96 changes: 96 additions & 0 deletions scripts/veomni/moe_split.py
@@ -0,0 +1,96 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Reverse of moe_merge.py: splits merged MoE expert weights back into individual expert weights.

This script takes a HF checkpoint that has been processed by moe_merge.py (where expert weights
are stacked into single tensors) and splits it back to the original format with individual
expert weights.

The process reverses the merging by:
1. Loading stacked tensors (e.g., model.layers.{i}.mlp.experts.gate_proj)
2. Unstacking them back into individual expert weights (e.g., model.layers.{i}.mlp.experts.{j}.gate_proj.weight)
3. Handling all three projection types: gate_proj, up_proj, down_proj

Usage: python moe_split.py --merge_hf_path <merged_checkpoint> --split_hf_path <output_dir>
"""


import os
from argparse import ArgumentParser
from dataclasses import dataclass
from glob import glob
from typing import Generator

import torch
from safetensors.torch import safe_open
from tqdm import tqdm
from transformers import AutoConfig
from veomni.models import build_tokenizer, save_model_weights


@dataclass
class StateDictIterator:
    filepath: str

    def __iter__(self) -> Generator[tuple[str, "torch.Tensor"], None, None]:
        if self.filepath.endswith(".safetensors"):
            with safe_open(self.filepath, framework="pt", device="cpu") as f:
                for key in f.keys():
                    yield key, f.get_tensor(key)

        else:
            state_dict = torch.load(self.filepath, map_location="cpu", weights_only=True, mmap=True)
            for key in state_dict.keys():
                yield key, state_dict[key]


def main(merge_hf_path, split_hf_path):
    torch.set_default_dtype(torch.bfloat16)
    os.makedirs(split_hf_path, exist_ok=True)

    config = AutoConfig.from_pretrained(merge_hf_path)
    tokenizer = build_tokenizer(merge_hf_path)

    safetensor_files = list(glob(os.path.join(merge_hf_path, "*.safetensors")))
    safetensor_files.sort()
    state_dict_iterators = [StateDictIterator(shard_file) for shard_file in safetensor_files]
    new_state_dict = {}
    for state_dict_iterator in tqdm(state_dict_iterators, desc="Loading checkpoint shards"):
        for name, tensor in state_dict_iterator:
            new_state_dict[name] = tensor.cpu()

    num_experts = config.num_experts
    num_hidden_layers = config.num_hidden_layers
    for i in range(num_hidden_layers):
        print(f"Converting layer {i}")
        for proj_name in ["gate_proj", "up_proj", "down_proj"]:
            stacked_key = f"model.layers.{i}.mlp.experts.{proj_name}"
            if stacked_key in new_state_dict:
                stacked_tensor = new_state_dict.pop(stacked_key)
                for j in range(num_experts):
                    expert_key = f"model.layers.{i}.mlp.experts.{j}.{proj_name}.weight"
                    new_state_dict[expert_key] = stacked_tensor[j]

    model_assets = [config, tokenizer]

    print("Saving to safetensors")
    save_model_weights(split_hf_path, new_state_dict, model_assets=model_assets)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--merge_hf_path", type=str, required=True)
    parser.add_argument("--split_hf_path", type=str, required=True)
    args = parser.parse_args()
    main(args.merge_hf_path, args.split_hf_path)
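
The stack/unstack round trip the two scripts perform can be checked in isolation; a minimal sketch with synthetic tensors (shapes are made up, not taken from a real checkpoint):

import torch

# Synthetic per-expert weights standing in for model.layers.{i}.mlp.experts.{j}.<proj>.weight
num_experts, out_features, in_features = 4, 8, 16
experts = [torch.randn(out_features, in_features) for _ in range(num_experts)]

# moe_merge.py stores one stacked tensor per projection type
stacked = torch.stack(experts)
assert stacked.shape == (num_experts, out_features, in_features)

# moe_split.py recovers each expert by indexing the stacked tensor; the round trip is exact
assert all(torch.equal(stacked[j], experts[j]) for j in range(num_experts))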
6 changes: 6 additions & 0 deletions verl/trainer/config/engine/veomni.yaml
@@ -1,6 +1,12 @@
# Target class for this configuration
_target_: verl.workers.config.VeOmniEngineConfig

# Whether to offload model parameters to CPU
param_offload: False

# Whether to offload optimizer state to CPU
optimizer_offload: False

# fsdp or fsdp2
data_parallel_mode: fsdp2

78 changes: 78 additions & 0 deletions verl/utils/veomni_utils.py
@@ -0,0 +1,78 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from verl.utils.device import get_device_id, get_torch_device


@torch.no_grad()
def offload_veomni_model_to_cpu(model, empty_cache: bool = True):
    from veomni.distributed.parallel_state import get_parallel_state

    assert get_parallel_state().dp_mode == "fsdp2", "Only support fsdp2 offloading for VeOmni model"

    model.cpu()
    if empty_cache:
        get_torch_device().empty_cache()


@torch.no_grad()
def load_veomni_model_to_gpu(model):
    from veomni.distributed.parallel_state import get_parallel_state

    assert get_parallel_state().dp_mode == "fsdp2", "Only support fsdp2 offloading for VeOmni model"

    device = get_device_id()
    model.to(device)


@torch.no_grad()
def offload_veomni_optimizer(optimizer):
    optimizers = []
    # Check if this is a MultiOptimizer (used for ep and non-ep parameters when ep+fsdp2 is enabled)
    if hasattr(optimizer, "_is_multi_optimizer") and optimizer._is_multi_optimizer:
        optimizers.extend(optimizer.optimizers_dict.values())
    else:
        optimizers.append(optimizer)

    for opt in optimizers:
        if not opt.state:
            continue
        for param_group in opt.param_groups:
            for param in param_group["params"]:
                state = opt.state[param]
                for key, value in state.items():
                    if isinstance(value, torch.Tensor):
                        state[key] = value.to("cpu", non_blocking=True)


@torch.no_grad()
def load_veomni_optimizer(optimizer, device_id):
    optimizers = []
    # Check if this is a MultiOptimizer (used for ep and non-ep parameters when ep+fsdp2 is enabled)
    if hasattr(optimizer, "_is_multi_optimizer") and optimizer._is_multi_optimizer:
        optimizers.extend(optimizer.optimizers_dict.values())
    else:
        optimizers.append(optimizer)

    for opt in optimizers:
        if not opt.state:
            continue
        for param_group in opt.param_groups:
            for param in param_group["params"]:
                state = opt.state[param]
                for key, value in state.items():
                    if isinstance(value, torch.Tensor):
                        state[key] = value.to(device_id, non_blocking=True)
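
A hypothetical usage sketch of these helpers (the function names and call sites below are assumptions, not part of this diff): bounce the VeOmni model and optimizer state to host memory between training phases and bring them back before the next update.

from verl.utils.device import get_device_id
from verl.utils.veomni_utils import (
    load_veomni_model_to_gpu,
    load_veomni_optimizer,
    offload_veomni_model_to_cpu,
    offload_veomni_optimizer,
)


def release_gpu_between_phases(model, optimizer):
    # After an optimizer step: move parameters and optimizer state to CPU so rollout can use the GPU memory.
    offload_veomni_model_to_cpu(model, empty_cache=True)
    offload_veomni_optimizer(optimizer)


def restore_before_update(model, optimizer):
    # Before the next update: reload parameters and optimizer state onto the accelerator.
    load_veomni_model_to_gpu(model)
    load_veomni_optimizer(optimizer, get_device_id())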
36 changes: 35 additions & 1 deletion verl/workers/engine/veomni/transformer_impl.py
@@ -31,9 +31,15 @@
from verl.trainer.config import CheckpointConfig
from verl.utils import tensordict_utils as tu
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from verl.utils.device import get_device_id
from verl.utils.device import get_device_id, get_device_name
from verl.utils.fsdp_utils import fsdp_version
from verl.utils.profiler import log_gpu_memory_usage
from verl.utils.veomni_utils import (
    load_veomni_model_to_gpu,
    load_veomni_optimizer,
    offload_veomni_model_to_cpu,
    offload_veomni_optimizer,
)
from verl.workers.config import HFModelConfig, VeOmniEngineConfig, VeOmniOptimizerConfig
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager

@@ -217,6 +223,34 @@ def _build_model_optimizer(self):
            self.engine_config.activation_gpu_limit,
        )

    def to(self, device: str, model: bool = True, optimizer: bool = True, grad: bool = True):
        """
        Move model parameters, optimizer states, or both to the specified device.
        Note that this method runs regardless of the offload config; it serves as manual control.

        Args:
            device: Target device identifier.
            model: If True, move the model.
            optimizer: If True, move the optimizer states.
            grad: Forwarded to the parent implementation, which controls whether gradients are moved.
        """
        super(FSDPEngine, self).to(device=device, model=model, optimizer=optimizer, grad=grad)

        device_name = get_device_name()

        assert device in (device_name, "cpu")
        if device == device_name:
            if model:
                load_veomni_model_to_gpu(self.module)
            if optimizer and self.optimizer is not None:
                load_veomni_optimizer(self.optimizer, device)
        elif device == "cpu":
            if model:
                offload_veomni_model_to_cpu(self.module)
            if optimizer and self.optimizer is not None:
                offload_veomni_optimizer(self.optimizer)
        else:
            raise ValueError(f"Invalid device type: {device}")

    def optimizer_step(self):
        """
        Perform an optimization step using the optimizer.
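
To illustrate the manual-control semantics of the new to() override, a hedged usage sketch (the helper name and the rollout callback are assumptions, not part of this diff):

from verl.utils.device import get_device_name


def run_rollout_with_offload(engine, rollout_fn):
    # Hypothetical helper: offload the training engine while another component uses the GPU,
    # then restore parameters and optimizer state before training resumes.
    engine.to("cpu")
    result = rollout_fn()
    engine.to(get_device_name())
    return result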