20 changes: 20 additions & 0 deletions tests/special_e2e/sft/run_sft_engine.sh
@@ -112,6 +112,22 @@ TORCHTITAN_ENGINE_CONFIG="\
engine.data_parallel_shard_size=${FSDP_SIZE} \
engine.use_torch_compile=False"

AUTOMODEL_ENGINE_CONFIG="\
engine=${backend} \
model=hf_model \
model.path=${MODEL_PATH} \
optim=${backend} \
optim.lr=1e-5 \
optim.lr_warmup_steps_ratio=0.2 \
optim.weight_decay=0.1 \
optim.betas="[0.9,0.95]" \
optim.clip_grad=1.0 \
optim.min_lr_ratio=0.1 \
optim.lr_scheduler_type=cosine \
engine.tp_size=${TP_SIZE} \
engine.cp_size=${CP_SIZE} \
engine.use_torch_compile=False"


if [ "$backend" = "fsdp" ]; then
ENGINE_CONFIG="$FSDP_ENGINE_CONFIG"
@@ -125,6 +141,10 @@ elif [ "$backend" = "torchtitan" ]; then
ENGINE_CONFIG="$TORCHTITAN_ENGINE_CONFIG"
echo "Using torchtitan engine"
exp_name=gsm8k-${backend}-tp${TP_SIZE}-pp${PP_SIZE}-cp${CP_SIZE}-dp${FSDP_SIZE}-pad-${PAD_MODE}-use_remove_padding-${USE_REMOVE_PADDING}-mode-${mode}
elif [ "$backend" = "automodel" ]; then
ENGINE_CONFIG="$AUTOMODEL_ENGINE_CONFIG"
echo "Using automodel engine"
exp_name=gsm8k-${backend}-tp${TP_SIZE}-pp${PP_SIZE}-cp${CP_SIZE}-dp${FSDP_SIZE}-pad-${PAD_MODE}-use_remove_padding-${USE_REMOVE_PADDING}-mode-${mode}
else
ENGINE_CONFIG="$MEGATRON_ENGINE_CONFIG"
echo "Using megatron engine"
8 changes: 8 additions & 0 deletions tests/special_e2e/sft/test_sft_engine_all.sh
@@ -46,6 +46,14 @@ BACKEND=megatron TP_SIZE=2 PP_SIZE=2 VPP_SIZE=${VPP_SIZE} CP_SIZE=2 NUM_GPUS=8 m
# echo "run with tp2 pp1 cp1 fsdp2 num_gpus4"
# BACKEND=torchtitan TP_SIZE=2 PP_SIZE=1 CP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=4 bash tests/special_e2e/sft/run_sft_engine.sh

# # test with automodel dp=2
# echo "run with automodel tp1 pp1 cp1 dp2 num_gpus2"
# BACKEND=automodel TP_SIZE=1 PP_SIZE=1 CP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=2 bash tests/special_e2e/sft/run_sft_engine.sh

# # test with automodel tp2 dp=2
# echo "run with automodel tp2 pp1 cp1 dp2 num_gpus4"
# BACKEND=automodel TP_SIZE=2 PP_SIZE=1 CP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=4 bash tests/special_e2e/sft/run_sft_engine.sh

python3 tests/special_e2e/sft/compare_sft_engine_results.py

rm -rf ~/verl/test/log
1 change: 1 addition & 0 deletions tests/special_sanity/check_device_api_usage.py
@@ -43,6 +43,7 @@
"verl/workers/engine/fsdp/transformer_impl.py", # appear in default device_name
"verl/workers/engine/veomni/transformer_impl.py", # appear in default device_name
"verl/workers/engine/torchtitan/transformer_impl.py", # appear in default device_name
"verl/workers/engine/automodel/transformer_impl.py", # appear in default device_name
"verl/workers/rollout/vllm_rollout/vllm_async_server.py", # appear in config.cudagraph_capture_sizes
"verl/workers/rollout/sglang_rollout/async_sglang_server.py", # manually set CUDA_VISIBLE_DEVICES
"verl/workers/rollout/trtllm_rollout/trtllm_async_server.py", # appear in config.cudagraph_capture_sizes
76 changes: 76 additions & 0 deletions verl/trainer/config/engine/automodel.yaml
@@ -0,0 +1,76 @@
# Target class for this configuration
_target_: verl.workers.config.AutomodelEngineConfig

# Backend strategy identifier
strategy: automodel

# Distributed training strategy: "fsdp2", "megatron_fsdp", or "ddp"
distributed_strategy: fsdp2

# Parallelism sizes
tp_size: 1
pp_size: 1
cp_size: 1
ep_size: 1
dp_replicate_size: 1
sequence_parallel: false
defer_fsdp_grad_sync: true

# Whether to offload model parameters to CPU
param_offload: false

# Whether to offload optimizer state to CPU
optimizer_offload: false

# Whether to enable activation checkpointing
activation_checkpointing: false

# Whether to enable FP8 training
enable_fp8: false

# Whether to enable torch.compile for the model
enable_compile: false

# Model data type for loading weights ("fp32", "bf16", "fp16")
model_dtype: fp32

# Attention implementation ("sdpa", "flash_attention_2", "eager", "te")
attn_implementation: sdpa

# Backend settings (nemo_automodel BackendConfig)
use_te_backend: false
rope_fusion: true
gate_precision: null
enable_hf_state_dict_adapter: true
enable_fsdp_optimizations: false

# MoE / Expert Parallelism settings
enable_deepep: false
reshard_after_forward: false
fake_balanced_gate: false
ignore_router_for_ac: false
lm_head_precision: null
wrap_outer_model: true

# Mixed precision policy (FSDP2 MixedPrecisionPolicy)
mp_param_dtype: bf16
mp_reduce_dtype: fp32
mp_output_dtype: bf16

# Random seed for reproducibility
seed: 42

# Whether to enable full determinism for distributed training, only for debugging
full_determinism: false

# Whether to use forward only mode
forward_only: false

# Whether to use torch compile for entropy computation
use_torch_compile: false

# Whether to use chunked entropy computation
entropy_from_logits_with_chunking: false

# Whether to use checkpointing for entropy computation
entropy_checkpointing: false
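The three `mp_*` fields above map onto FSDP2's mixed precision policy. A minimal sketch of that mapping, assuming the `MixedPrecisionPolicy` class exposed by recent PyTorch (older releases keep it under `torch.distributed._composable.fsdp`); how nemo_automodel actually wires the policy in is not part of this diff:

```python
import torch
from torch.distributed.fsdp import MixedPrecisionPolicy

_DTYPES = {"fp32": torch.float32, "bf16": torch.bfloat16, "fp16": torch.float16}

def make_mp_policy(mp_param_dtype="bf16", mp_reduce_dtype="fp32", mp_output_dtype="bf16"):
    # Matching the yaml defaults: parameters computed in bf16, gradient
    # reductions accumulated in fp32, module outputs cast back to bf16.
    return MixedPrecisionPolicy(
        param_dtype=_DTYPES[mp_param_dtype],
        reduce_dtype=_DTYPES[mp_reduce_dtype],
        output_dtype=_DTYPES[mp_output_dtype],
    )
```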
47 changes: 47 additions & 0 deletions verl/trainer/config/optim/automodel.yaml
@@ -0,0 +1,47 @@
# Target class for this configuration
_target_: verl.workers.config.AutomodelOptimizerConfig

optimizer: AdamW

# Module path to import optimizer from
optimizer_impl: torch.optim

# Learning rate (maps to max_lr in Automodel's OptimizerParamScheduler)
lr: 1e-5

# LR warmup steps ratio (used when lr_warmup_steps <= 0)
lr_warmup_steps_ratio: 0.0

# Total training steps (injected at runtime)
total_training_steps: -1

# Weight decay
weight_decay: 0.01

# LR warmup steps (set > 0 to override lr_warmup_steps_ratio)
lr_warmup_steps: -1

# Betas for Adam optimizer
betas: [0.9, 0.999]

# Clip gradient norm
clip_grad: 1.0

# Initial LR ratio for warmup start (init_lr = lr * init_lr_ratio)
init_lr_ratio: 0.1

# Minimum LR ratio after decay (min_lr = lr * min_lr_ratio)
min_lr_ratio: 0.01

# LR scheduler type (Automodel OptimizerParamScheduler decay style)
# Options: "constant", "cosine", "linear", "inverse-square-root"
lr_scheduler_type: cosine

# Weight decay increment style: "constant", "linear", or "cosine"
wd_incr_style: constant

# Kept for backward compatibility (unused by Automodel scheduler)
num_cycles: 0.5
zero_indexed_step: true

override_optimizer_config: {}
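Per the comments above, warmup starts at `init_lr = lr * init_lr_ratio` and decay bottoms out at `min_lr = lr * min_lr_ratio`, with the warmup length derived from `lr_warmup_steps_ratio` when `lr_warmup_steps <= 0`. A minimal sketch of the warmup-then-cosine shape these knobs imply; Automodel's `OptimizerParamScheduler` owns the real implementation, which may differ in detail:

```python
import math

def lr_at_step(step, lr=1e-5, init_lr_ratio=0.1, min_lr_ratio=0.01,
               lr_warmup_steps_ratio=0.0, total_training_steps=1000):
    init_lr, min_lr = lr * init_lr_ratio, lr * min_lr_ratio
    warmup = int(lr_warmup_steps_ratio * total_training_steps)
    if warmup > 0 and step < warmup:
        # Linear warmup from init_lr up to the peak lr.
        return init_lr + (lr - init_lr) * step / warmup
    # Cosine decay from the peak lr down to min_lr over the remaining steps.
    progress = min(1.0, (step - warmup) / max(1, total_training_steps - warmup))
    return min_lr + 0.5 * (lr - min_lr) * (1 + math.cos(math.pi * progress))

# e.g. with the 20% warmup used in run_sft_engine.sh:
# lr_at_step(0, lr_warmup_steps_ratio=0.2) returns init_lr (1e-6)
```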
3 changes: 2 additions & 1 deletion verl/utils/dataset/multiturn_sft_dataset.py
@@ -63,7 +63,8 @@ def print_assembled_message(tokenizer, message_list, input_ids, loss_mask, attn_
sep = "\n\n"
str = f"tokenized entire message:\n{tokenized}"
str += sep
str += f"tokenized seperately :\n{tokenizer.decode(input_ids)}"
decoded_ids = input_ids.tolist() if hasattr(input_ids, "tolist") else input_ids
str += f"tokenized seperately :\n{tokenizer.decode(decoded_ids)}"

logger.debug(str)
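The added `tolist()` guard lets `print_assembled_message` accept `input_ids` either as a tensor or as a plain list of token ids before decoding. A standalone sketch of that duck-typing normalization (torch is only needed here to demonstrate the tensor case):

```python
import torch

def to_id_list(input_ids):
    # torch.Tensor (and numpy arrays) expose .tolist(); plain lists pass through.
    return input_ids.tolist() if hasattr(input_ids, "tolist") else input_ids

assert to_id_list(torch.tensor([1, 2, 3])) == [1, 2, 3]
assert to_id_list([1, 2, 3]) == [1, 2, 3]
```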

99 changes: 99 additions & 0 deletions verl/workers/config/engine.py
@@ -29,6 +29,7 @@
"TrainingWorkerConfig",
"TorchtitanEngineConfig",
"VeOmniEngineConfig",
"AutomodelEngineConfig",
"EngineConfig",
"EngineRouterReplayConfig",
]
@@ -369,6 +370,104 @@ def __post_init__(self):
assert self.strategy in ["torchtitan"], f"strategy {self.strategy} not supported"


@dataclass
class AutomodelEngineConfig(EngineConfig):
"""Configuration for Automodel (nemo_automodel) backend.

The Automodel backend uses NeMoAutoModelForCausalLM for model loading and
supports FSDP2, MegatronFSDP, and DDP distributed strategies with optional
TP, CP, and EP parallelism.

Args:
strategy (str): Backend strategy identifier, must be "automodel".
distributed_strategy (str): Distributed training strategy: "fsdp2", "megatron_fsdp", or "ddp".
tp_size (int): Tensor parallel size.
pp_size (int): Pipeline parallel size (only pp_size=1 supported initially).
cp_size (int): Context parallel size.
ep_size (int): Expert parallel size for MoE models.
dp_replicate_size (int): Data-parallel replicate size for HSDP. 1 = pure sharding.
sequence_parallel (bool): Enable sequence parallelism in the TP plan.
defer_fsdp_grad_sync (bool): Defer FSDP gradient sync to the final micro-batch.
activation_checkpointing (bool): Whether to enable activation checkpointing.
enable_fp8 (bool): Whether to enable FP8 training.
enable_compile (bool): Whether to enable torch.compile for the model.
model_dtype (str): Model data type for loading weights. "fp32" loads in float32
(matching FSDP golden), "auto" uses the dtype from the model config.
attn_implementation (str): Attention implementation to use ("sdpa", "flash_attention_2", "eager", "te").

Backend settings (nemo_automodel BackendConfig):
use_te_backend (bool): Use TransformerEngine attn/linear/rms_norm.
rope_fusion (bool): Enable RoPE fusion (requires TransformerEngine).
gate_precision (Optional[str]): Precision for MoE gate/router weights (e.g. "fp32", "bf16").
enable_hf_state_dict_adapter (bool): Enable HuggingFace state dict compatibility.
enable_fsdp_optimizations (bool): Enable FSDP-specific optimizations in TE layers.

MoE / Expert Parallelism settings:
enable_deepep (bool): Enable DeepEP for distributed expert parallelism.
reshard_after_forward (bool): Reshard parameters after forward pass in MoE parallelizer.
fake_balanced_gate (bool): Use balanced gate for performance analysis.
ignore_router_for_ac (bool): Use selective activation checkpointing that saves router outputs.
lm_head_precision (Optional[str]): Custom precision for lm_head layer (e.g. "fp32").
wrap_outer_model (bool): Wrap outer model in FSDP if it differs from inner model.

Mixed precision policy (FSDP2):
mp_param_dtype (str): Parameter dtype for FSDP2 mixed precision policy.
mp_reduce_dtype (str): Reduce dtype for FSDP2 mixed precision policy.
mp_output_dtype (str): Output dtype for FSDP2 mixed precision policy.

Entropy computation:
entropy_from_logits_with_chunking (bool): Whether to use chunked entropy computation.
use_torch_compile (bool): Whether to use torch.compile for entropy computation.
entropy_checkpointing (bool): Whether to use checkpointing for entropy computation.
"""

strategy: str = "automodel"
distributed_strategy: str = "fsdp2"
# Parallelism sizes
tp_size: int = 1
pp_size: int = 1
cp_size: int = 1
ep_size: int = 1
dp_replicate_size: int = 1
sequence_parallel: bool = False
defer_fsdp_grad_sync: bool = True
# Model settings
activation_checkpointing: bool = False
enable_fp8: bool = False
enable_compile: bool = False
model_dtype: str = "fp32"
attn_implementation: str = "sdpa"
# Backend settings (BackendConfig)
use_te_backend: bool = False
rope_fusion: bool = True
gate_precision: Optional[str] = None
enable_hf_state_dict_adapter: bool = True
enable_fsdp_optimizations: bool = False
# MoE / Expert Parallelism settings
enable_deepep: bool = False
reshard_after_forward: bool = False
fake_balanced_gate: bool = False
ignore_router_for_ac: bool = False
lm_head_precision: Optional[str] = None
wrap_outer_model: bool = True
# Mixed precision policy
mp_param_dtype: str = "bf16"
mp_reduce_dtype: str = "fp32"
mp_output_dtype: str = "bf16"
# Entropy computation
entropy_from_logits_with_chunking: bool = False
use_torch_compile: bool = True
entropy_checkpointing: bool = False

def __post_init__(self):
super().__post_init__()
assert self.strategy == "automodel", f"strategy must be 'automodel', got {self.strategy}"
assert self.distributed_strategy in ["fsdp2", "megatron_fsdp", "ddp"], (
f"distributed_strategy {self.distributed_strategy} not supported"
)
assert self.pp_size == 1, "Pipeline parallelism (pp_size > 1) is not yet supported for automodel backend"
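A minimal usage sketch of the validation above, assuming verl is importable in the environment; the `pp_size` guard enforces the docstring's note that only `pp_size=1` is supported initially:

```python
from verl.workers.config import AutomodelEngineConfig

cfg = AutomodelEngineConfig(tp_size=2, cp_size=2)  # fine: fsdp2 with TP and CP
try:
    AutomodelEngineConfig(pp_size=2)               # rejected in __post_init__
except AssertionError as err:
    print(err)  # "Pipeline parallelism (pp_size > 1) is not yet supported ..."
```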


@dataclass
class TrainingWorkerConfig(BaseConfig):
model_type: str = None # model type (language_model/value_model)
38 changes: 38 additions & 0 deletions verl/workers/config/optimizer.py
@@ -26,6 +26,7 @@
"build_optimizer",
"VeOmniOptimizerConfig",
"TorchtitanOptimizerConfig",
"AutomodelOptimizerConfig",
]


@@ -170,6 +171,43 @@ class TorchtitanOptimizerConfig(OptimizerConfig):
min_lr_factor: float = 0.0


@dataclass
class AutomodelOptimizerConfig(OptimizerConfig):
"""Automodel optimizer configuration extending base OptimizerConfig.

Uses the same optimizer building mechanism as FSDP (dynamic import from optimizer_impl).
LR scheduling is handled by Automodel's OptimizerParamScheduler.

Args:
optimizer (str): Optimizer class name (e.g., "AdamW").
optimizer_impl (str): Module path to import optimizer from (e.g., "torch.optim").
lr (float): Learning rate (maps to max_lr in OptimizerParamScheduler).
init_lr_ratio (Optional[float]): Initial LR ratio for warmup start (init_lr = lr * init_lr_ratio).
min_lr_ratio (Optional[float]): Minimum LR ratio after decay (min_lr = lr * min_lr_ratio).
lr_scheduler_type (str): LR decay style: "constant", "cosine", "linear", or "inverse-square-root".
wd_incr_style (str): Weight decay increment style: "constant", "linear", or "cosine".
num_cycles (float): Kept for backward compatibility (unused by Automodel scheduler).
zero_indexed_step (bool): Kept for backward compatibility (unused by Automodel scheduler).
"""

_mutable_fields = OptimizerConfig._mutable_fields.copy()
_mutable_fields.add("lr_scheduler_type")

optimizer: str = "AdamW"
optimizer_impl: str = "torch.optim"
init_lr_ratio: Optional[float] = 0.1
min_lr_ratio: Optional[float] = 0.01
lr_scheduler_type: str = "cosine"
wd_incr_style: str = "constant"
num_cycles: float = 0.5
override_optimizer_config: Optional[dict] = None
zero_indexed_step: bool = True

def __post_init__(self):
assert self.lr_scheduler_type in ["constant", "cosine", "linear", "inverse-square-root"]
return super().__post_init__()
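And a short sketch of the scheduler-type validation, again assuming verl is importable; any decay style outside the four supported strings is rejected at construction time:

```python
from verl.workers.config import AutomodelOptimizerConfig

ok = AutomodelOptimizerConfig(lr=1e-5, lr_scheduler_type="inverse-square-root")
try:
    AutomodelOptimizerConfig(lr=1e-5, lr_scheduler_type="polynomial")
except AssertionError:
    print("unsupported lr_scheduler_type rejected in __post_init__")
```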


def build_optimizer(parameters, config: FSDPOptimizerConfig):
"""Build an optimizer based on the configuration.

8 changes: 8 additions & 0 deletions verl/workers/engine/__init__.py
@@ -37,6 +37,14 @@
VeOmniEngine = None
VeOmniEngineWithLMHead = None

try:
from .automodel import AutomodelEngine, AutomodelEngineWithLMHead

__all__ += ["AutomodelEngine", "AutomodelEngineWithLMHead"]
except ImportError:
AutomodelEngine = None
AutomodelEngineWithLMHead = None
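Because the import is optional, `AutomodelEngine` resolves to `None` when nemo_automodel is not installed. A minimal sketch of the guard a call site would use under that convention (the error message is illustrative, not verl's actual wording):

```python
from verl.workers.engine import AutomodelEngine

if AutomodelEngine is None:
    raise ImportError("automodel backend requested but nemo_automodel is not installed")
engine_cls = AutomodelEngine
```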

# Mindspeed must be imported before Megatron to ensure the related monkey patches take effect as expected
try:
from .mindspeed import MindspeedEngineWithLMHead
20 changes: 20 additions & 0 deletions verl/workers/engine/automodel/__init__.py
@@ -0,0 +1,20 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .transformer_impl import AutomodelEngine, AutomodelEngineWithLMHead

__all__ = [
"AutomodelEngine",
"AutomodelEngineWithLMHead",
]