Commit 312a8cb
[SGLang] Add support between mcore 0.11 and sglang (#1055)

Based on the ongoing alignment between mcore and vllm (#851), we can advance the alignment between mcore and sglang in parallel, since their interfaces are similar. In the end, the rollout engine only needs to receive a generator over the model's parameters. [link](sgl-project/sglang#5345)
1 parent 8d36311 commit 312a8cb
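Note: a minimal sketch of the "generator parameter" idea, assuming only what this diff shows — the rollout engine consumes a lazy iterator of (name, tensor) pairs; the names below are illustrative, not a fixed public API.

    from typing import Iterator, Tuple
    import torch

    def toy_per_tensor_generator() -> Iterator[Tuple[str, torch.Tensor]]:
        # In this commit the real producer is verl.utils.megatron_utils.per_tensor_generator,
        # which lazily yields HuggingFace-style (name, tensor) pairs gathered from Megatron ranks.
        yield "model.embed_tokens.weight", torch.zeros(8, 8)

    # Both the vLLM and SGLang backends can then consume the same iterator, e.g.
    # inference_engine.update_weights_from_tensor(toy_per_tensor_generator(), load_format=None)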

File tree: 5 files changed, +334 −4 lines
Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+set -x
+
+gsm8k_train_path=$HOME/data/gsm8k/train.parquet
+gsm8k_test_path=$HOME/data/gsm8k/test.parquet
+math_train_path=$HOME/data/math/train.parquet
+math_test_path=$HOME/data/math/test.parquet
+
+train_files="['$gsm8k_train_path', '$math_train_path']"
+test_files="['$gsm8k_test_path', '$math_test_path']"
+
+python3 -m verl.trainer.main_ppo --config-path=config \
+    --config-name='ppo_megatron_trainer.yaml' \
+    algorithm.adv_estimator=grpo \
+    data.train_files="$train_files" \
+    data.val_files="$test_files" \
+    data.train_batch_size=1024 \
+    data.max_prompt_length=1024 \
+    data.max_response_length=1024 \
+    data.filter_overlong_prompts=True \
+    data.truncation='error' \
+    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \
+    actor_rollout_ref.actor.optim.lr=1e-6 \
+    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
+    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
+    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \
+    actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=2 \
+    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=4 \
+    actor_rollout_ref.actor.use_kl_loss=True \
+    actor_rollout_ref.actor.kl_loss_coef=0.001 \
+    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
+    actor_rollout_ref.actor.entropy_coeff=0 \
+    actor_rollout_ref.model.enable_gradient_checkpointing=True \
+    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
+    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
+    actor_rollout_ref.rollout.name=sglang \
+    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
+    actor_rollout_ref.rollout.n=5 \
+    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
+    algorithm.use_kl_in_reward=False \
+    trainer.critic_warmup=0 \
+    trainer.logger=['console','wandb'] \
+    trainer.project_name='verl_grpo_example_gsm8k' \
+    trainer.experiment_name='qwen2_7b_function_rm_megatron' \
+    trainer.n_gpus_per_node=8 \
+    trainer.nnodes=1 \
+    trainer.save_freq=-1 \
+    trainer.test_freq=5 \
+    trainer.total_epochs=15 $@

verl/utils/megatron_utils.py

Lines changed: 139 additions & 0 deletions
@@ -32,6 +32,8 @@
 from transformers import PretrainedConfig

 from verl.utils.torch_dtypes import PrecisionType
+from verl.utils.model import normalize_model_name
+import verl.utils.megatron.tensor_parallel as tp_utils


 def get_model_config(model):
@@ -619,3 +621,140 @@ def broadcast_str_from_megatron_pp(obj: Any):
     torch.distributed.broadcast_object_list(object_list=obj_output, src=global_rank, group=mpu.get_pipeline_model_parallel_group())

     return obj_output[0]
+
+def default_tp_concat_fn(layer_name_mapping, name, train_params, infer_params, model_config, convert_qkv_gate_up_by_simple_split=False):
+    """
+    name: name of the parameter
+    train_params: training parameters
+    infer_params (Iterable[torch.Tensor]): an iterator over the list of parameters all-gathered from micro_dp_group
+    model_config: huggingface model_config
+    TODO(zhangchi.usc1992): currently, the implementation is ad hoc. We can move this function to the model
+    definition so that it is model-agnostic. If the model doesn't implement this function,
+    we can throw an error to force the user to disable the TP HybridEngine.
+    """
+    from megatron.core import mpu
+
+    if layer_name_mapping.get("qkv_layer_name") in name and "layer_norm" not in name:
+        # if the tensor is qkv, for each param on tp, split into q, k, v
+        # concat q, k, v separately.
+        q_lst = []
+        k_lst = []
+        v_lst = []
+        assert model_config.num_attention_heads % model_config.num_key_value_heads == 0
+        num_q_per_kv = model_config.num_attention_heads // model_config.num_key_value_heads
+        assert infer_params[0].shape[0] % (num_q_per_kv + 2) == 0
+        kv_size_per_tp = infer_params[0].shape[0] // (num_q_per_kv + 2)
+        split_size = [kv_size_per_tp * num_q_per_kv, kv_size_per_tp, kv_size_per_tp]
+        for infer_param in infer_params:
+            num_query_groups_per_partition = model_config.num_key_value_heads // mpu.get_tensor_model_parallel_world_size()
+            for chunk in infer_param.chunk(num_query_groups_per_partition):
+                split_size = [
+                    kv_size_per_tp * num_q_per_kv // num_query_groups_per_partition,
+                    kv_size_per_tp // num_query_groups_per_partition,
+                    kv_size_per_tp // num_query_groups_per_partition
+                ]
+                q, k, v = chunk.split(split_size)
+                q_lst.append(q)
+                k_lst.append(k)
+                v_lst.append(v)
+        q = torch.cat(q_lst, dim=0)
+        k = torch.cat(k_lst, dim=0)
+        v = torch.cat(v_lst, dim=0)
+        if not convert_qkv_gate_up_by_simple_split:
+            infer_params = torch.cat((q, k, v), dim=0)
+        else:
+            infer_params = [q, k, v]
+
+    elif layer_name_mapping.get("gate_proj_layer_name") in name:
+        # if the tensor is gate and proj
+        gate_lst = []
+        up_lst = []
+        for infer_param in infer_params:
+            gate, up = infer_param.chunk(2)
+            gate_lst.append(gate)
+            up_lst.append(up)
+        gate = torch.cat(gate_lst, dim=0)
+        up = torch.cat(up_lst, dim=0)
+        if not convert_qkv_gate_up_by_simple_split:
+            infer_params = torch.cat((gate, up), dim=0)
+        else:
+            infer_params = [gate, up]
+
+    else:
+        # concat tensor
+        infer_params = torch.cat(infer_params, dim=tp_utils.get_tensor_parallel_partition_dim(train_params))
+
+    return infer_params
+
+
+def per_tensor_generator(actor_module, model_config, weight_converter, layer_name_mapping, convert_qkv_gate_up_by_simple_split=True):
+    from megatron.core import parallel_state as mpu
+    pp_rank = mpu.get_pipeline_model_parallel_rank()
+    pp_size = mpu.get_pipeline_model_parallel_world_size()
+    vpp_size = len(actor_module)
+    all_gather_group = mpu.get_tensor_model_parallel_group()
+    all_gather_group_size = torch.distributed.get_world_size(group=all_gather_group)
+
+    def tensor_generator():
+        for scan_vpp_idx in range(vpp_size):
+            yield from actor_module[scan_vpp_idx].named_parameters()
+
+    # first, make every rank aware of the full model's parameter layout
+    meta_info = []
+    for scan_vpp_idx in range(vpp_size):
+        for idx, (name, _) in enumerate(actor_module[scan_vpp_idx].named_parameters()):
+            meta_info.append((pp_rank, scan_vpp_idx, idx, name))
+
+    obj_spec_output = [None] * mpu.get_pipeline_model_parallel_world_size()
+    torch.distributed.all_gather_object(
+        object_list=obj_spec_output, obj=meta_info, group=mpu.get_pipeline_model_parallel_group()
+    )
+    layer_list_meta = [item for sublist in obj_spec_output for item in sublist]
+
+    gen_func = tensor_generator()
+
+    # lazily load tensors for the full model
+    for cur_pp_rank, scan_vpp_idx, idx, name in layer_list_meta:
+        if cur_pp_rank == pp_rank:
+            try:
+                cur_name, cur_tensor = next(gen_func)
+            except StopIteration:
+                cur_name, cur_tensor = None, None
+            cur_name = normalize_model_name(
+                name, cur_pp_rank, scan_vpp_idx, pp_size, vpp_size, model_config.num_hidden_layers
+            )
+        else:
+            cur_tensor, cur_name = None, None
+
+        # pp broadcast model tensor and name
+        cur_name = broadcast_str_from_megatron_pp(cur_name)
+        broad_pp_tensor = broadcast_from_megatron_pp(cur_tensor)
+
+        # (xya): this is a hack to fix the name of the parameters
+        while cur_name.startswith("module."):
+            cur_name = cur_name[len("module."):]
+
+        # tp all gather
+        if tp_utils.is_tensor_parallel_param(broad_pp_tensor):
+            # allocate a new tensor with proper size
+            if all_gather_group_size <= 1:
+                infer_params = [broad_pp_tensor]
+            else:
+                infer_params = [torch.empty_like(broad_pp_tensor) for _ in range(all_gather_group_size)]
+                torch.distributed.all_gather(
+                    infer_params, broad_pp_tensor, group=mpu.get_tensor_model_parallel_group()
+                )
+            infer_params = default_tp_concat_fn(
+                layer_name_mapping, cur_name, broad_pp_tensor, infer_params, model_config, convert_qkv_gate_up_by_simple_split
+            )
+        else:
+            infer_params = broad_pp_tensor
+
+        if not isinstance(infer_params, list):
+            infer_params = [infer_params]
+        converted_names, converted_params = weight_converter.convert_param(cur_name, infer_params)
+
+        yield from zip(converted_names, converted_params)
+
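For reference, the QKV branch of default_tp_concat_fn above undoes Megatron's fused, grouped-query layout before the weights reach the inference engine: each TP shard stores, per local query group, num_q_per_kv query row blocks followed by one key block and one value block. A minimal sketch of that splitting arithmetic, with made-up sizes (the head counts and head_dim below are illustrative assumptions, not taken from the diff):

    import torch

    # Assumed GQA geometry: 8 query heads, 2 KV heads, head_dim 4, TP world size 2.
    num_attention_heads, num_key_value_heads, head_dim, tp_size = 8, 2, 4, 2
    num_q_per_kv = num_attention_heads // num_key_value_heads        # 4
    num_query_groups_per_partition = num_key_value_heads // tp_size  # 1 group per shard

    # One shard's fused QKV slab: (num_q_per_kv + 2) row blocks per local group.
    kv_size_per_tp = num_query_groups_per_partition * head_dim
    shard = torch.randn((num_q_per_kv + 2) * kv_size_per_tp, 16)

    q_lst, k_lst, v_lst = [], [], []
    for chunk in shard.chunk(num_query_groups_per_partition):
        split_size = [
            kv_size_per_tp * num_q_per_kv // num_query_groups_per_partition,
            kv_size_per_tp // num_query_groups_per_partition,
            kv_size_per_tp // num_query_groups_per_partition,
        ]
        q, k, v = chunk.split(split_size)
        q_lst.append(q)
        k_lst.append(k)
        v_lst.append(v)

    # After looping over every TP shard's slab, q/k/v are concatenated separately,
    # which is what default_tp_concat_fn does before handing them to the weight converter.
    q, k, v = torch.cat(q_lst), torch.cat(k_lst), torch.cat(v_lst)
    print(q.shape, k.shape, v.shape)  # torch.Size([16, 16]) torch.Size([4, 16]) torch.Size([4, 16])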

verl/workers/megatron_workers.py

Lines changed: 29 additions & 4 deletions
@@ -209,6 +209,10 @@ def megatron_actor_model_provider(pre_process, post_process):
             return actor_module, actor_optimizer, self.hf_config, optim_config

     def _build_rollout(self, trust_remote_code=False):
+        layer_name_mapping = {
+            "qkv_layer_name": "self_attention.linear_qkv.",
+            "gate_proj_layer_name": "linear_fc1.weight",
+        }
         if self.config.rollout.name == "vllm":
             from torch.distributed.device_mesh import init_device_mesh

@@ -217,10 +221,7 @@ def _build_rollout(self, trust_remote_code=False):

             # NOTE(sgm): If the QKV and gate_up projection layers are concatenated together in the actor,
             # we will reorganize their weight format when resharding from actor to rollout.
-            layer_name_mapping = {
-                "qkv_layer_name": "self_attention.linear_qkv.",
-                "gate_proj_layer_name": "linear_fc1.weight",
-            }
+

             infer_tp = self.config.rollout.tensor_model_parallel_size
             dp = self.world_size // infer_tp
@@ -259,6 +260,30 @@ def _build_rollout(self, trust_remote_code=False):
                 weight_converter=weight_converter,
             )
             log_gpu_memory_usage("After building sharding manager", logger=logger)
+        elif self.config.rollout.name == 'sglang':
+            from verl.workers.rollout.sglang_rollout import SGLangRollout
+            # NOTE(linjunrong): Due to recent fp8 support in SGLang, importing any symbol related to SGLang's model_runner now checks the CUDA device capability.
+            # However, under veRL's setting the main Ray process cannot see any CUDA device, which would potentially lead to:
+            # "RuntimeError: No CUDA GPUs are available".
+            # For this reason, sharding_manager.__init__ should not import FSDPSGLangShardingManager, so we import it here via its absolute path.
+            # check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76
+            from verl.workers.sharding_manager.megatron_sglang import MegatronSGLangShardingManager
+            local_path = copy_to_local(self.config.model.path)
+            log_gpu_memory_usage(f'Before building {self.config.rollout.name} rollout', logger=None)
+            rollout = SGLangRollout(actor_module=local_path,
+                                    config=self.config.rollout,
+                                    tokenizer=self.tokenizer,
+                                    model_hf_config=self.actor_model_config)
+            log_gpu_memory_usage(f'After building {self.config.rollout.name} rollout', logger=None)
+
+            from verl.models.mcore import get_mcore_weight_converter
+            weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype)
+            sharding_manager = MegatronSGLangShardingManager(actor_module=self.actor.actor_module,
+                                                             inference_engine=rollout.inference_engine,
+                                                             model_config=self.actor_model_config,
+                                                             layer_name_mapping=layer_name_mapping,
+                                                             weight_converter=weight_converter)
+            log_gpu_memory_usage('After building sharding manager', logger=logger)
         else:
             raise NotImplementedError("Only vllmRollout is supported with Megatron now")
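A quick note on layer_name_mapping, now hoisted above the backend branches so that the vllm and sglang paths share it: its entries are plain substrings matched against Megatron parameter names inside default_tp_concat_fn. A tiny illustration (the parameter name below is a hypothetical example, not taken from the diff):

    layer_name_mapping = {
        "qkv_layer_name": "self_attention.linear_qkv.",
        "gate_proj_layer_name": "linear_fc1.weight",
    }

    # Hypothetical Megatron-style parameter name, for illustration only.
    name = "decoder.layers.0.self_attention.linear_qkv.weight"

    is_fused_qkv = layer_name_mapping["qkv_layer_name"] in name and "layer_norm" not in name
    is_fused_gate_up = layer_name_mapping["gate_proj_layer_name"] in name
    print(is_fused_qkv, is_fused_gate_up)  # True False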

verl/workers/rollout/sglang_rollout/sglang_rollout.py

Lines changed: 9 additions & 0 deletions
@@ -370,3 +370,12 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
         self.inference_engine._engine.flush_cache()

         return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)
+
+    # this function is left for uniform train-inference resharding
+    def update_weights(self, params_iter):
+        self.inference_engine.resume_memory_occupation()
+        self.inference_engine.update_weights_from_tensor(params_iter, load_format=None)
+
+    # this function is left for uniform train-inference resharding
+    def offload(self):
+        self.inference_engine.release_memory_occupation()
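Together with per_tensor_generator, these two hooks give the Megatron worker a uniform resharding surface. A hedged sketch of how a caller might drive them (the method names are the ones added in this diff; the wrapper function and its arguments are placeholders for illustration):

    from verl.utils.megatron_utils import per_tensor_generator

    def reshard_generate_release(rollout, actor_module, model_config, weight_converter,
                                 layer_name_mapping, prompts):
        # Stream the current training weights into the SGLang engine (waking it up first).
        params_iter = per_tensor_generator(actor_module, model_config,
                                           weight_converter, layer_name_mapping)
        rollout.update_weights(params_iter)
        try:
            return rollout.generate_sequences(prompts)
        finally:
            # Release the memory SGLang holds so training can reclaim the GPU.
            rollout.offload()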
verl/workers/sharding_manager/megatron_sglang.py

Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This file contains a Megatron-style Hybrid Engine that shares the weights of the actor with the inference engine.
+"""
+
+import importlib
+import logging
+import os
+import torch
+import torch.distributed as dist
+from torch import nn
+
+from verl.utils.model import normalize_model_name
+from verl.utils.megatron_utils import broadcast_from_megatron_pp, broadcast_str_from_megatron_pp
+
+from verl.utils.megatron_utils import get_model, unwrap_model
+from verl.utils.debug import log_gpu_memory_usage
+from verl.utils.megatron_utils import convert_megatron_model_to_transformers_model
+
+logger = logging.getLogger(__file__)
+logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN'))
+"""
+Megatron Hybrid Engine:
+- During training, only the current pp stage holds the parameters
+- Before inference, broadcast the parameters of the current pp rank to all other pp ranks (all pp ranks hold all the parameters)
+- Bind the parameters to the inference engine
+- Do inference in tp. pp is treated as additional dp
+- After inference, all the parameters that don't belong to this pp rank are freed.
+"""
+
+from .base import BaseShardingManager
+
+import torch
+from torch import nn
+import torch.distributed
+from torch.distributed import new_group
+from torch.distributed._tensor import DTensor
+from typing import Dict, Iterable, Union, Tuple
+
+from verl import DataProto
+from verl.protocol import all_gather_data_proto
+from verl.utils.torch_functional import (broadcast_dict_tensor, allgather_dict_tensors)
+from sglang.srt.entrypoints.verl_engine import VerlEngine
+from verl.utils.debug import GPUMemoryLogger
+
+import verl.utils.megatron.tensor_parallel as tp_utils
+from verl.utils.megatron_utils import per_tensor_generator, default_tp_concat_fn
+
+_MICRO_DATA_PARALLEL_GROUP = None
+
+
+class MegatronSGLangShardingManager(BaseShardingManager):
+
+    def __init__(self, actor_module: nn.ModuleList, inference_engine: VerlEngine, model_config, layer_name_mapping, weight_converter):
+        from megatron.core import parallel_state as mpu
+        self.actor_module = actor_module
+        self.inference_engine = inference_engine
+        self.model_config = model_config
+        self.layer_name_mapping = layer_name_mapping
+        self.weight_converter = weight_converter
+        global _MICRO_DATA_PARALLEL_GROUP
+        world_size = torch.distributed.get_world_size()
+        rank = torch.distributed.get_rank()
+
+        self.infer_tp_size = self.inference_engine._tp_size
+        self.train_tp_size = mpu.get_tensor_model_parallel_world_size()
+        self.need_tp_reshard = self.infer_tp_size == self.train_tp_size
+
+        assert self.infer_tp_size <= self.train_tp_size, \
+            'Not implemented for infer_tp > train_tp'
+        assert self.train_tp_size % self.infer_tp_size == 0
+
+        micro_dp_size = self.train_tp_size // self.infer_tp_size
+        num_micro_dp_groups = world_size // micro_dp_size
+        assert _MICRO_DATA_PARALLEL_GROUP is None, ("micro data parallel group is already initialized")
+        for i in range(num_micro_dp_groups):
+            ranks = range(i * micro_dp_size, (i + 1) * micro_dp_size)
+            group = new_group(ranks=ranks)
+            if rank in ranks:
+                _MICRO_DATA_PARALLEL_GROUP = group
+
+    @GPUMemoryLogger(role="MegatronSGLangShardingManager enter", logger=logger)
+    def __enter__(self):
+        per_tensor_param = per_tensor_generator(self.actor_module, self.model_config, self.weight_converter, self.layer_name_mapping)
+        self.inference_engine.resume_memory_occupation()
+        self.inference_engine.update_weights_from_tensor(per_tensor_param, load_format=None)
+
+    @GPUMemoryLogger(role="MegatronSGLangShardingManager exit", logger=logger)
+    def __exit__(self, exc_type, exc_value, traceback):
+        log_gpu_memory_usage('Before SGLang offload in sharding manager', logger=logger)
+        self.inference_engine.release_memory_occupation()
+        log_gpu_memory_usage('After SGLang offload in sharding manager', logger=logger)
+
+        for model in self.actor_module:
+            model.train()
+        # empty the CUDA cache after each compute
+        torch.cuda.empty_cache()
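In practice the worker does not call the rollout's resharding hooks directly; MegatronSGLangShardingManager performs the same engine calls in __enter__/__exit__, so generation runs inside a context manager. A sketch of the intended usage, assuming the objects built in _build_rollout above (plus a DataProto of prompts) are passed in:

    from verl.workers.sharding_manager.megatron_sglang import MegatronSGLangShardingManager

    def rollout_with_resharding(actor, rollout, actor_model_config, layer_name_mapping,
                                weight_converter, prompts):
        # Weights are streamed into SGLang on __enter__ and released again on __exit__.
        sharding_manager = MegatronSGLangShardingManager(
            actor_module=actor.actor_module,            # Megatron modules, one per virtual pp stage
            inference_engine=rollout.inference_engine,  # sglang VerlEngine
            model_config=actor_model_config,
            layer_name_mapping=layer_name_mapping,
            weight_converter=weight_converter,
        )
        # On exit, SGLang's memory is released and the actor modules go back to train mode.
        with sharding_manager:
            return rollout.generate_sequences(prompts)  # rollout sees up-to-date actor weights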
