
Commit b7af31a

support qwen3-next (#385)
* extract BaseMegatronMapper
* rename MegatronMapper
* make update_mapping only called in '_inner_map_*'
* extract Megatron Mapper for VLM
* clean docstring
* fix pylint
* fix pylint
* remove src_arch
* make some mapping functions be fully configurable
* init commit
* test qwen2_5_vl
* fix pylint
* fix issues when PP > 1
* passing a copy to avoid inplace modification on fp32 logits
* fix issue
* add draft version of qwen3-vl
* add how to build image for qwen3-next
* fix param_sync
* Add SGLANG PATCH to README
* fix readme and scripts
* fix memory_pool.py overwrite in readme
* demo
* demo update
* fix wandb logging
* demo update
* fix convergence issue
* update readme
* fix pylint

---------

Co-authored-by: Peng Li <[email protected]>
1 parent f5f86e8 commit b7af31a

26 files changed, +1711 -802 lines

chatlearn/algorithm/grpo_utils/megatron_policy_trainer.py

Lines changed: 1 addition & 1 deletion
@@ -249,7 +249,7 @@ def train_step(self, data_list: List[Dict[str, Any]], **kwargs):
             num_zeros_in_grad,
             self.stats,
             {},
-            "policy_trainer",
+            "",
             self._metric_list,
         )

chatlearn/algorithm/grpo_utils/megatron_utils/train_helper.py

Lines changed: 12 additions & 10 deletions
@@ -113,32 +113,32 @@ def training_log(
     if is_last_rank():
 
         for key in loss_dict:
-            iter_dict[f"{name}/{key}"] = loss_dict[key]
-            consumed_train_samples_dict[f"{name}/" + key + " vs samples"] = loss_dict[
+            iter_dict[f"{key}"] = loss_dict[key]
+            consumed_train_samples_dict[key + " vs samples"] = loss_dict[
                 key
             ]
 
         if grad_norm is not None:
-            iter_dict[f"{name}/" + "grad_norm"] = grad_norm
-            consumed_train_samples_dict[f"{name}/" + "grad-norm vs samples"] = grad_norm
+            iter_dict["grad_norm"] = grad_norm
+            consumed_train_samples_dict["grad-norm vs samples"] = grad_norm
 
         if more_grad_norm is not None:
             for k in more_grad_norm:
-                iter_dict[f"{name}/{k}" + " grad_norm"] = more_grad_norm[k]
-                consumed_train_samples_dict[f"{name}/{k}" + " grad-norm vs samples"] = (
+                iter_dict[f"{k}" + " grad_norm"] = more_grad_norm[k]
+                consumed_train_samples_dict[f"{k}" + " grad-norm vs samples"] = (
                     more_grad_norm[k]
                 )
 
         if params_norm is not None:
-            iter_dict[f"{name}/" + "params-norm"] = params_norm
-            consumed_train_samples_dict[f"{name}/" + "params-norm vs samples"] = (
+            iter_dict["params-norm"] = params_norm
+            consumed_train_samples_dict["params-norm vs samples"] = (
                 params_norm
            )
 
         elapsed_time = 0
         elapsed_time_per_iteration = elapsed_time / total_iterations
         if args.log_timers_to_tensorboard:
-            iter_dict[f"{name}/" + "iteration-time"] = elapsed_time_per_iteration
+            iter_dict["iteration-time"] = elapsed_time_per_iteration
 
         log_string = " iteration {:8d}/infinity |".format(iteration)
         log_string += " consumed samples: {:12d} |".format(args.consumed_train_samples)
@@ -560,9 +560,11 @@ def forward_step(data_iterator, model, *, is_training: bool=False, is_packing: b
         'input_ids': inputs["all_tokens"],
         'position_ids': inputs["all_token_position_ids"],
         'labels': inputs["labels"] if not is_training else None,
-        'packed_seq_params': inputs['packed_seq_params'] if is_packing else None
     }
 
+    if is_packing:
+        kwargs.update({'packed_seq_params': inputs['packed_seq_params']})
+
     if 'pixel_values' in inputs:
         kwargs.update({
             'vision_data': inputs["pixel_values"],

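The second hunk above stops passing packed_seq_params as an explicit None and only attaches it when packing is enabled. A minimal standalone sketch of that pattern, assuming a hypothetical forward signature (not the actual Megatron model API):

# Minimal sketch of the conditional-kwargs pattern from forward_step.
# forward() is a hypothetical stand-in, not the real model.
def forward(input_ids, position_ids, labels=None):
    # A forward that does not accept `packed_seq_params` would raise a
    # TypeError if the key were passed explicitly, even with a None value.
    return len(input_ids)

inputs = {
    "all_tokens": [1, 2, 3],
    "all_token_position_ids": [0, 1, 2],
    "labels": None,
    "packed_seq_params": object(),   # placeholder for the packing metadata
}
is_packing = False

kwargs = {
    "input_ids": inputs["all_tokens"],
    "position_ids": inputs["all_token_position_ids"],
    "labels": inputs["labels"],
}
if is_packing:
    # Only forward the packing metadata when sequence packing is on.
    kwargs.update({"packed_seq_params": inputs["packed_seq_params"]})

print(forward(**kwargs))   # prints 3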
chatlearn/configs/megatron_config.py

Lines changed: 13 additions & 0 deletions
@@ -70,6 +70,7 @@ class MegatronModelArchitectureConfig(BaseConfig):
         default=1000000,
         metadata={"help": "Base to use for rotary positional embeddings"},
     )
+    rotary_percent: float = 1.0
     group_query_attention: bool = field(
         default=False, metadata={"help": "Use group-query attention."}
     )
@@ -245,6 +246,11 @@ class MegatronModelArchitectureConfig(BaseConfig):
     freeze_VP: bool = field(
         default=False, metadata={"help": "Freeze vision projection layers"}
     )
+
+    hybrid_override_pattern: Optional[str] = None
+    is_hybrid_model: bool = False
+    apply_layernorm_1p: bool = False
+
     def _post_init_impl(self):
         if self.moe_aux_loss_coeff == 0:
             self.moe_router_load_balancing_type = 'none'
@@ -329,6 +335,12 @@ class MegatronConfig(BaseConfig):
         }
     )
 
+    use_expandable_segments: bool = field(
+        default=False, metadata={"help": "Whether to use expandable_segments in PYTORCH_CUDA_ALLOC_CONF, \
+            to avoid large reserved memory in ref and policy trainer workers; expandable_segments should be False \
+            while in parameter sync for efficiency"}
+    )
+
     def _validate_impl(self):
         assert self.num_gpu > 0, "Megatron-Core requires at least one GPU"
         assert self.num_gpu % self.num_replica == 0, \
@@ -443,6 +455,7 @@ class MegatronPolicyTrainerConfig(PolicyTrainerConfig, MegatronConfig):
             "help": "Load model for finetuning. Do not load optimizer or rng state from checkpoint and set iteration to 0."
         },
     )
+    distributed_timeout_minutes: int = 10
 
     def _validate_impl(self):
         assert self.calculate_per_token_loss, "Per-Token-Loss is required for Training."

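For readers unfamiliar with this config style, the field(..., metadata={"help": ...}) entries double as self-documentation. A stripped-down sketch of the two new runtime knobs (BaseConfig and the surrounding fields omitted; names and defaults taken from the diff, help text paraphrased):

from dataclasses import dataclass, field, fields

@dataclass
class RuntimeKnobsSketch:
    # Simplified stand-in for the new runtime options added in this commit.
    use_expandable_segments: bool = field(
        default=False,
        metadata={"help": "Use expandable_segments in PYTORCH_CUDA_ALLOC_CONF to avoid "
                          "large reserved memory in ref/policy trainer workers; keep it "
                          "off during parameter sync for efficiency."},
    )
    distributed_timeout_minutes: int = 10

# The metadata "help" strings can be surfaced as documentation:
for f in fields(RuntimeKnobsSketch):
    print(f"{f.name} (default={f.default}): {f.metadata.get('help', '')}")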
chatlearn/models/megatron_module.py

Lines changed: 4 additions & 10 deletions
@@ -16,7 +16,6 @@
 import re
 from dataclasses import fields
 
-import inspect
 import torch
 
 try:
@@ -123,6 +122,8 @@ def model_setup(self):
         """
         :meta private:
         """
+        if self.module_args.use_expandable_segments:
+            torch.cuda.memory._set_allocator_settings("expandable_segments:True")
         super().model_setup()
 
         # TODO: we may need to let setup return model, optimizer and opt_param_scheduler
@@ -255,17 +256,10 @@ def map_local_param_name_to_global(self):
         self.global_name_to_local_name = {}
         # NOTE: this regex is for model with TEGroupedGEMM
         # SequentialMLP or GroupedMLP is not supported
-        regex = re.compile(r"(.*)decoder.layers\.(\d+)\.([a-z0-9_.]+)([\._])([a-z]+)([0-9]*)")
+        regex = re.compile(r"(.*)decoder.layers\.(\d+)\.([a-zA-Z0-9_.]+)([\._])([a-zA-Z]+)([0-9]*)")
         for vp_stage, model_chunk in enumerate(self.model):
             model_config = unwrap_model(model_chunk).config
-            if 'vp_stage' in inspect.signature(get_transformer_layer_offset).parameters:
-                offset = get_transformer_layer_offset(model_config, vp_stage=vp_stage)
-            else:
-                if len(self.model) > 1:
-                    mpu.set_virtual_pipeline_model_parallel_rank(vp_stage)
-                offset = get_transformer_layer_offset(model_config)
-                if len(self.model) > 1:
-                    mpu.set_virtual_pipeline_model_parallel_rank(None)
+            offset = get_transformer_layer_offset(model_config, vp_stage=vp_stage)
             if model_config.num_moe_experts is not None:
                 ep_rank = mpu.get_expert_model_parallel_rank()
                 ep_size = mpu.get_expert_model_parallel_world_size()

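The only functional change to the regex above is widening the [a-z...] classes to [a-zA-Z...], so parameter names containing capital letters are also recognized. A quick standalone check (patterns copied from the diff; the example name is illustrative, not necessarily a real qwen3-next parameter key):

import re

OLD = re.compile(r"(.*)decoder.layers\.(\d+)\.([a-z0-9_.]+)([\._])([a-z]+)([0-9]*)")
NEW = re.compile(r"(.*)decoder.layers\.(\d+)\.([a-zA-Z0-9_.]+)([\._])([a-zA-Z]+)([0-9]*)")

# Illustrative mixed-case name, e.g. an A_log-style weight in a hybrid/linear-attention layer.
name = "module.decoder.layers.0.mixer.A_log"

print(OLD.match(name))            # None: lowercase-only classes reject the capital letter
print(NEW.match(name).groups())   # ('module.', '0', 'mixer.A', '_', 'log', '')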
chatlearn/models/sglang_module.py

Lines changed: 12 additions & 0 deletions
@@ -412,6 +412,12 @@ def generate(self, query: List[Dict], is_eval: bool) -> List[Dict]:
             self.flush_cache()
         return outputs
 
+    def dump_parameters(self, dump_path_root):
+        os.makedirs(dump_path_root, exist_ok=True)
+        self.onload()
+        self.llm.save_sharded_model(path=dump_path_root, pattern=None, max_size=None)
+        self.offload()
+
     def update_weights_from_ipc_handles(self, reduce_data):
         gathered_data = None
         if self.is_engine():
@@ -729,6 +735,12 @@ async def generate(self, query: List[Dict], is_eval: bool) -> List[Dict]:
             )
         return outputs
 
+    async def dump_parameters(self, dump_path_root):
+        os.makedirs(dump_path_root, exist_ok=True)
+        await self.onload()
+        self.llm.save_sharded_model(path=dump_path_root, pattern=None, max_size=None)
+        await self.offload()
+
     async def generate_per_request(self, query: Dict, is_eval: bool) -> Dict:
         outputs = None
         if self.is_engine():

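Both new dump helpers bracket the sharded save with onload()/offload(). A small standalone sketch of that bracket using dummy stand-ins (only the call shape from the diff is assumed, not the real SGLang engine API); adding try/finally, as below, would also guarantee the offload runs if the save raises:

import os
import tempfile

class FakeEngine:
    # Dummy stand-in that only mimics the save call shape used above.
    def save_sharded_model(self, path, pattern=None, max_size=None):
        open(os.path.join(path, "shard-00000"), "wb").close()

class RolloutSketch:
    def __init__(self):
        self.llm = FakeEngine()
        self.on_gpu = False

    def onload(self):
        self.on_gpu = True    # the real module moves weights back to GPU here

    def offload(self):
        self.on_gpu = False   # the real module frees GPU memory here

    def dump_parameters(self, dump_path_root):
        os.makedirs(dump_path_root, exist_ok=True)
        self.onload()
        try:
            self.llm.save_sharded_model(path=dump_path_root, pattern=None, max_size=None)
        finally:
            self.offload()

RolloutSketch().dump_parameters(tempfile.mkdtemp())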
chatlearn/runtime/engine.py

Lines changed: 1 addition & 1 deletion
@@ -556,7 +556,7 @@ def _resume_from_data_checkpoint(self):
     def dump_parameters(self, dump_path):
         for _, model in enumerate(self.models):
             replic_0 = model.replicas[0]
-            if isinstance(replic_0, DistVLLMActor):
+            if isinstance(replic_0, (DistVLLMActor, DistSGLangActor)):
                 future.wait(replic_0.engine.dump_parameters.remote(dump_path))
 
     def save_checkpoint(self, episode_id):

chatlearn/synchronizer/mappers/__init__.py

Lines changed: 20 additions & 10 deletions
@@ -22,20 +22,30 @@
 def get_mapper_name(src_model: 'DistModel', dst_model: 'DistModel'):
     src_type = src_model.runtime_args.train_backend
     dst_type = dst_model.runtime_args.rollout_backend
-    if src_type == 'megatron' and dst_type == 'vllm':
-        return "MegatronVLLMMapper"
-    elif src_type == 'megatron' and dst_type == 'sglang':
-        return "MegatronSGLangMapper"
-    else:
-        raise NotImplementedError(f"Unsupported src/dst model combination: {src_type}-{dst_type}")
+    model_type = src_model.runtime_args.model_type  # llm or vlm
+
+    mapping = {
+        'llm-megatron-vllm': "MegatronVLLMMapper-LLM",
+        'llm-megatron-sglang': "MegatronSGLangMapper-LLM",
+        'vlm-megatron-vllm': "MegatronVLLMMapper-VLM",
+        'vlm-megatron-sglang': "MegatronSGLangMapper-VLM",
+    }
+    key = f'{model_type}-{src_type}-{dst_type}'
+    if key not in mapping:
+        raise NotImplementedError(f"Unsupported src/dst model combination: {key}")
+    return mapping[key]
 
 
 def name_to_mapper_cls(mapper_name: str):
     # pylint: disable=import-outside-toplevel
     from .mapping_helpers import VLLM_HELPERS, HF_HELPERS
-    if mapper_name in ["MegatronVLLMMapper", "MegatronSGLangMapper"]:
-        from .mapper import MegatronMapper
-        helper_mappings = {"MegatronVLLMMapper": VLLM_HELPERS, "MegatronSGLangMapper": HF_HELPERS}
-        return partial(MegatronMapper, mapper_config=helper_mappings[mapper_name])
+    if mapper_name in ["MegatronVLLMMapper-LLM", "MegatronSGLangMapper-LLM"]:
+        from .megatron_llm_mapper import MegatronLLMMapper
+        helper_mappings = {"MegatronVLLMMapper-LLM": VLLM_HELPERS, "MegatronSGLangMapper-LLM": HF_HELPERS}
+        return partial(MegatronLLMMapper, mapper_config=helper_mappings[mapper_name])
+    elif mapper_name in ["MegatronVLLMMapper-VLM", "MegatronSGLangMapper-VLM"]:
+        from .megatron_vlm_mapper import MegatronVLMMapper
+        helper_mappings = {"MegatronVLLMMapper-VLM": VLLM_HELPERS, "MegatronSGLangMapper-VLM": HF_HELPERS}
+        return partial(MegatronVLMMapper, mapper_config=helper_mappings[mapper_name])
     else:
         raise ValueError(f"Unrecognized Mapper {mapper_name}")

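name_to_mapper_cls returns a functools.partial with the backend-specific helper table pre-bound, so the caller constructs the mapper later with only the remaining arguments. A standalone sketch of that pattern (the mapper class, its constructor arguments, and the helper value below are dummies, not the real ChatLearn classes):

from functools import partial

# Dummy stand-ins; the real MegatronLLMMapper/MegatronVLMMapper constructors may differ.
class DummyMapper:
    def __init__(self, src_model, dst_model, mapper_config):
        self.mapper_config = mapper_config

DUMMY_VLLM_HELPERS = {"weight_format": "vllm"}

def name_to_mapper_cls_sketch(mapper_name):
    table = {"MegatronVLLMMapper-LLM": DUMMY_VLLM_HELPERS}
    # Pre-bind the backend-specific helpers; src/dst models are supplied later.
    return partial(DummyMapper, mapper_config=table[mapper_name])

mapper_cls = name_to_mapper_cls_sketch("MegatronVLLMMapper-LLM")
mapper = mapper_cls(src_model="megatron_policy", dst_model="vllm_rollout")
print(mapper.mapper_config)   # {'weight_format': 'vllm'}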