[model] feat: support qwen35 mtp sft/rl #5898
zpltys wants to merge 2 commits into verl-project:main
Conversation
Code Review
This pull request introduces support for Qwen3.5 MoE models with Multi-Token Prediction (MTP) for Megatron-based SFT and GRPO training. Changes include new example scripts, model registry updates, and enhanced configuration conversion logic to handle Qwen3.5-specific MTP attributes. Review feedback recommends making the internal MTP configuration helpers public for better cross-module reuse, addressing a potential bug where user-defined MTP loss scaling factors are ignored during configuration conversion, and eliminating redundant logic in the Megatron workers by consistently using the new helper functions.
def _get_mtp_num_layers(hf_config):
    """Get MTP layer count from various config formats.

    Supports:
    - num_nextn_predict_layers (DeepSeek, Qwen3 style)
    - mtp_num_hidden_layers (Qwen3.5 style, in hf_config or text_config)
    """
    if hasattr(hf_config, "num_nextn_predict_layers") and hf_config.num_nextn_predict_layers > 0:
        return hf_config.num_nextn_predict_layers
    if hasattr(hf_config, "mtp_num_hidden_layers") and hf_config.mtp_num_hidden_layers > 0:
        return hf_config.mtp_num_hidden_layers
    if hasattr(hf_config, "text_config") and hasattr(hf_config.text_config, "mtp_num_hidden_layers"):
        if hf_config.text_config.mtp_num_hidden_layers > 0:
            return hf_config.text_config.mtp_num_hidden_layers
    return 0


def _set_mtp_num_layers(hf_config, value: int):
    """Set MTP layer count in the appropriate config field."""
    if hasattr(hf_config, "num_nextn_predict_layers"):
        hf_config.num_nextn_predict_layers = value
    elif hasattr(hf_config, "mtp_num_hidden_layers"):
        hf_config.mtp_num_hidden_layers = value
    elif hasattr(hf_config, "text_config") and hasattr(hf_config.text_config, "mtp_num_hidden_layers"):
        hf_config.text_config.mtp_num_hidden_layers = value
The MTP configuration helpers _get_mtp_num_layers and _set_mtp_num_layers are useful across different modules (e.g., in config_converter.py and megatron_workers.py). They should be made public by removing the leading underscore to follow Python naming conventions for shared utilities and to avoid linting issues when accessed from other packages.
def get_mtp_num_layers(hf_config):
    """Get MTP layer count from various config formats.

    Supports:
    - num_nextn_predict_layers (DeepSeek, Qwen3 style)
    - mtp_num_hidden_layers (Qwen3.5 style, in hf_config or text_config)
    """
    if hasattr(hf_config, "num_nextn_predict_layers") and hf_config.num_nextn_predict_layers > 0:
        return hf_config.num_nextn_predict_layers
    if hasattr(hf_config, "mtp_num_hidden_layers") and hf_config.mtp_num_hidden_layers > 0:
        return hf_config.mtp_num_hidden_layers
    if hasattr(hf_config, "text_config") and hasattr(hf_config.text_config, "mtp_num_hidden_layers"):
        if hf_config.text_config.mtp_num_hidden_layers > 0:
            return hf_config.text_config.mtp_num_hidden_layers
    return 0


def set_mtp_num_layers(hf_config, value: int):
    """Set MTP layer count in the appropriate config field."""
    if hasattr(hf_config, "num_nextn_predict_layers"):
        hf_config.num_nextn_predict_layers = value
    elif hasattr(hf_config, "mtp_num_hidden_layers"):
        hf_config.mtp_num_hidden_layers = value
    elif hasattr(hf_config, "text_config") and hasattr(hf_config.text_config, "mtp_num_hidden_layers"):
        hf_config.text_config.mtp_num_hidden_layers = value

References
- According to PEP-8, a leading underscore is used to indicate that a method or variable is intended for internal use within a module or class. If the utility is intended to be shared across different modules, the leading underscore should be removed to make it public. (link)
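As a quick sanity check of the helper's precedence across the config shapes described above, here is a self-contained sketch using `types.SimpleNamespace` to stand in for HF config objects (the mock attribute values are illustrative, not taken from real model configs):

```python
from types import SimpleNamespace


def get_mtp_num_layers(hf_config):
    """Get MTP layer count from various config formats."""
    if hasattr(hf_config, "num_nextn_predict_layers") and hf_config.num_nextn_predict_layers > 0:
        return hf_config.num_nextn_predict_layers
    if hasattr(hf_config, "mtp_num_hidden_layers") and hf_config.mtp_num_hidden_layers > 0:
        return hf_config.mtp_num_hidden_layers
    if hasattr(hf_config, "text_config") and hasattr(hf_config.text_config, "mtp_num_hidden_layers"):
        if hf_config.text_config.mtp_num_hidden_layers > 0:
            return hf_config.text_config.mtp_num_hidden_layers
    return 0


# DeepSeek/Qwen3-style: MTP count lives in num_nextn_predict_layers
deepseek_style = SimpleNamespace(num_nextn_predict_layers=1)
# Qwen3.5-style: MTP count nested under text_config
qwen35_style = SimpleNamespace(text_config=SimpleNamespace(mtp_num_hidden_layers=2))
# No MTP fields at all -> falls through to 0
plain = SimpleNamespace()

print(get_mtp_num_layers(deepseek_style))  # 1
print(get_mtp_num_layers(qwen35_style))    # 2
print(get_mtp_num_layers(plain))           # 0
```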
        else False
    )
    hf_config = model_config.hf_config
    mtp_num_layers = _get_mtp_num_layers(hf_config)
        return
    elif not enable_mtp and has_mtp:
-       model_config.hf_config.num_nextn_predict_layers = 0
+       _set_mtp_num_layers(hf_config, 0)
    transformer_config = check_and_construct_configs(args, TransformerConfig)

    # MTP support: mtp_num_hidden_layers may be in hf_config or hf_config.text_config
    mtp_num_layers = 0
    if hasattr(hf_config, "mtp_num_hidden_layers"):
        mtp_num_layers = hf_config.mtp_num_hidden_layers
    elif hasattr(hf_config, "text_config") and hasattr(hf_config.text_config, "mtp_num_hidden_layers"):
        mtp_num_layers = hf_config.text_config.mtp_num_hidden_layers

    if mtp_num_layers > 0:
        transformer_config.mtp_num_layers = mtp_num_layers
        transformer_config.mtp_loss_scaling_factor = getattr(hf_config, "mtp_loss_scaling_factor", 0.1)
The current implementation has two issues:
- It duplicates the MTP layer detection logic and is less comprehensive than the `get_mtp_num_layers` helper (e.g., it misses `num_nextn_predict_layers` used by Qwen3 models).
- It potentially ignores user overrides for `mtp_loss_scaling_factor`. If the scaling factor is provided via `override_transformer_config_kwargs`, it is added to `args` but then removed by `check_and_construct_configs` if the `TransformerConfig` class doesn't explicitly define it. The subsequent `getattr(hf_config, ...)` call would then revert to the default value from the model config, ignoring the user's intent.
Suggested change:

    # Current
    transformer_config = check_and_construct_configs(args, TransformerConfig)

    # MTP support: mtp_num_hidden_layers may be in hf_config or hf_config.text_config
    mtp_num_layers = 0
    if hasattr(hf_config, "mtp_num_hidden_layers"):
        mtp_num_layers = hf_config.mtp_num_hidden_layers
    elif hasattr(hf_config, "text_config") and hasattr(hf_config.text_config, "mtp_num_hidden_layers"):
        mtp_num_layers = hf_config.text_config.mtp_num_hidden_layers

    if mtp_num_layers > 0:
        transformer_config.mtp_num_layers = mtp_num_layers
        transformer_config.mtp_loss_scaling_factor = getattr(hf_config, "mtp_loss_scaling_factor", 0.1)

    # Suggested
    # Capture MTP scaling factor before it's potentially removed by check_and_construct_configs
    mtp_loss_scaling_factor = override_transformer_config_kwargs.get(
        "mtp_loss_scaling_factor", getattr(hf_config, "mtp_loss_scaling_factor", 0.1)
    )
    transformer_config = check_and_construct_configs(args, TransformerConfig)

    # MTP support
    from verl.utils.megatron_utils import get_mtp_num_layers

    mtp_num_layers = get_mtp_num_layers(hf_config)
    if mtp_num_layers > 0:
        transformer_config.mtp_num_layers = mtp_num_layers
        transformer_config.mtp_loss_scaling_factor = mtp_loss_scaling_factor
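The capture-before-filter pattern recommended here can be isolated into a small resolver; a minimal sketch (the function name and the dict-based override argument are illustrative, not part of the PR):

```python
from types import SimpleNamespace


def resolve_mtp_loss_scaling_factor(override_kwargs, hf_config, default=0.1):
    """User override wins; otherwise fall back to the model config, then a hard default."""
    if "mtp_loss_scaling_factor" in override_kwargs:
        return override_kwargs["mtp_loss_scaling_factor"]
    return getattr(hf_config, "mtp_loss_scaling_factor", default)


hf_config = SimpleNamespace(mtp_loss_scaling_factor=0.3)

# User override takes priority over the value in hf_config
print(resolve_mtp_loss_scaling_factor({"mtp_loss_scaling_factor": 0.5}, hf_config))  # 0.5
# No override -> model config value
print(resolve_mtp_loss_scaling_factor({}, hf_config))  # 0.3
# Neither -> hard default
print(resolve_mtp_loss_scaling_factor({}, SimpleNamespace()))  # 0.1
```

Resolving the value before any config filtering runs means the user's intent survives even if `mtp_loss_scaling_factor` is stripped from `args` later.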
    # DeepSeek-style MTP field
    if hasattr(hf_config, "num_nextn_predict_layers"):
        hf_config.num_nextn_predict_layers = 0
    # Qwen3.5-style MTP field: mtp_num_hidden_layers
    if hasattr(hf_config, "mtp_num_hidden_layers"):
        hf_config.mtp_num_hidden_layers = 0
    if hasattr(hf_config, "text_config") and hasattr(hf_config.text_config, "mtp_num_hidden_layers"):
        hf_config.text_config.mtp_num_hidden_layers = 0
This logic for disabling MTP fields is redundant with the set_mtp_num_layers helper. Using the helper ensures consistency across the codebase when handling different MTP attribute names (DeepSeek vs Qwen style).
Suggested change:

    # Current
    # DeepSeek-style MTP field
    if hasattr(hf_config, "num_nextn_predict_layers"):
        hf_config.num_nextn_predict_layers = 0
    # Qwen3.5-style MTP field: mtp_num_hidden_layers
    if hasattr(hf_config, "mtp_num_hidden_layers"):
        hf_config.mtp_num_hidden_layers = 0
    if hasattr(hf_config, "text_config") and hasattr(hf_config.text_config, "mtp_num_hidden_layers"):
        hf_config.text_config.mtp_num_hidden_layers = 0

    # Suggested
    # Disable MTP fields
    from verl.utils.megatron_utils import set_mtp_num_layers

    set_mtp_num_layers(hf_config, 0)
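As a standalone check that the helper covers both config shapes when disabling MTP, a minimal sketch with mock configs (attribute values are illustrative):

```python
from types import SimpleNamespace


def set_mtp_num_layers(hf_config, value: int):
    """Set MTP layer count in the appropriate config field."""
    if hasattr(hf_config, "num_nextn_predict_layers"):
        hf_config.num_nextn_predict_layers = value
    elif hasattr(hf_config, "mtp_num_hidden_layers"):
        hf_config.mtp_num_hidden_layers = value
    elif hasattr(hf_config, "text_config") and hasattr(hf_config.text_config, "mtp_num_hidden_layers"):
        hf_config.text_config.mtp_num_hidden_layers = value


deepseek_style = SimpleNamespace(num_nextn_predict_layers=1)
qwen35_style = SimpleNamespace(text_config=SimpleNamespace(mtp_num_hidden_layers=1))

# One call handles both the DeepSeek-style and the Qwen3.5-style field
set_mtp_num_layers(deepseek_style, 0)
set_mtp_num_layers(qwen35_style, 0)

print(deepseek_style.num_nextn_predict_layers)         # 0
print(qwen35_style.text_config.mtp_num_hidden_layers)  # 0
```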
ArronHZG left a comment
Please follow Gemini's recommendation to refactor the code.
@@ -0,0 +1,157 @@
#!/usr/bin/env bash
Please reuse examples/sft/gsm8k/run_qwen3_5_megatron.sh with an additional MTP option.
    # DeepSeek-style MTP field
    if hasattr(hf_config, "num_nextn_predict_layers"):
        hf_config.num_nextn_predict_layers = 0
    # Qwen3.5-style MTP field: mtp_num_hidden_layers
megatron_workers.py has been deprecated; please do not modify it.
What does this PR do?
To run qwen35 with MTP, you should use mbridge with PR ISEEKYAN/mbridge#98.
Checklist Before Starting
- Format the PR title as `[{modules}] {type}: {description}` (this will be checked by the CI).
  - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`, `fully_async`, `one_step_off`, like `[megatron, fsdp, doc]`.
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`.
  - If this PR breaks any API, add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

Test
API and Usage Example
# Add code snippet or script demonstrating how to use this

Design & Code Changes
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
- Run `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`.
- Request CI via the `ci-request` channel in the `verl` Slack workspace. (If not accessible, please try the Feishu group (飞书群).)
- If this PR changes the `recipe` submodule, please also update the reference to the submodule commit via `git submodule update --remote` or `cd recipe && git pull origin main`.