
Commit 8849aaa

[veomni] refactor: minor refactoring to ensure veomni engine compatibility with forward_only mode (#4889)
### What does this PR do?

Refactored the VeOmni engine's `initialize` method into smaller sub-methods, and enabled compatibility with forward_only mode.

### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
    - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
- [ ] If your PR is related to the `recipe` submodule, please also update the reference to the submodule commit via `git submodule update --remote` or `cd recipe && git pull origin main`.
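The refactoring this PR describes, splitting a monolithic `initialize` into builder sub-methods and skipping optimizer construction in forward_only mode, can be sketched as follows. This is a minimal, hypothetical skeleton: `EngineConfig`, the placeholder return values, and the `Engine` class here are invented for illustration; only the method names (`initialize`, `_build_model_optimizer`, `_build_optimizer`, `_build_lr_scheduler`) and the `forward_only` gating follow the actual diff.

```python
from dataclasses import dataclass


@dataclass
class EngineConfig:
    """Hypothetical stand-in for VeOmniEngineConfig; only the field used here."""

    forward_only: bool = False


class Engine:
    """Sketch of the refactored initialize(): builders split into sub-methods."""

    def __init__(self, config: EngineConfig):
        self.config = config
        self.module = None
        self.optimizer = None
        self.lr_scheduler = None

    def initialize(self):
        self._build_model_optimizer()
        # Checkpoint manager and CPU offload setup would follow here.

    def _build_model_optimizer(self):
        module = object()  # placeholder for build_foundation_model(...)
        if not self.config.forward_only:
            optimizer = self._build_optimizer(module)
            lr_scheduler = self._build_lr_scheduler(optimizer)
        else:
            # Inference-only engines need no optimizer or scheduler state.
            optimizer = None
            lr_scheduler = None
        self.module = module
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

    def _build_optimizer(self, module):
        return "optimizer"  # placeholder for build_optimizer(...)

    def _build_lr_scheduler(self, optimizer):
        return "lr_scheduler"  # placeholder for build_lr_scheduler(...)


train_engine = Engine(EngineConfig(forward_only=False))
train_engine.initialize()
infer_engine = Engine(EngineConfig(forward_only=True))
infer_engine.initialize()
print(train_engine.optimizer, infer_engine.optimizer)  # → optimizer None
```

The point of the split is that a forward-only rollout worker can reuse the same engine class without paying for optimizer state it never steps.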
1 parent d1fb320 commit 8849aaa

File tree: 3 files changed, +75 -48 lines changed


`.github/PULL_REQUEST_TEMPLATE.md` (1 addition, 1 deletion)

```diff
@@ -6,7 +6,7 @@
 
 - [ ] Search for similar PRs. Paste at least one query link here: ...
 - [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
-  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`
+  - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`
   - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
   - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
   - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
```
`tests/special_sanity/check_pr_title.py` (1 addition, 1 deletion)

```diff
@@ -19,7 +19,7 @@
 pr_title = os.environ.get("PR_TITLE", "").strip()
 
 # Define rules
-allowed_modules = ["fsdp", "megatron", "sglang", "vllm", "rollout", "trainer"]
+allowed_modules = ["fsdp", "megatron", "veomni", "sglang", "vllm", "rollout", "trainer"]
 allowed_modules += ["tests", "training_utils", "recipe", "hardware", "deployment"]
 allowed_modules += ["ray", "worker", "single_controller", "misc", "docker", "ci"]
 allowed_modules += ["perf", "model", "algo", "env", "tool", "ckpt", "doc", "data", "cfg", "reward"]
```
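The CI sanity test enforces the `[{modules}] {type}: {description}` title format against these allow-lists. A minimal sketch of how such a validator might work is below; the regex, the `check_pr_title` function name, and the error handling are illustrative, not the exact contents of `check_pr_title.py`, though the module and type lists mirror the diff above (with `veomni` newly added).

```python
import re

# Module and type allow-lists mirroring the diff above.
allowed_modules = ["fsdp", "megatron", "veomni", "sglang", "vllm", "rollout", "trainer"]
allowed_modules += ["tests", "training_utils", "recipe", "hardware", "deployment"]
allowed_modules += ["ray", "worker", "single_controller", "misc", "docker", "ci"]
allowed_modules += ["perf", "model", "algo", "env", "tool", "ckpt", "doc", "data", "cfg", "reward"]
allowed_types = ["feat", "fix", "refactor", "chore", "test"]


def check_pr_title(title: str) -> bool:
    """Return True if the title matches `[BREAKING]?[mod, mod] type: description`."""
    m = re.match(r"^(\[BREAKING\])?\[([^\]]+)\]\s*(\w+):\s*\S", title)
    if m is None:
        return False
    modules = [mod.strip() for mod in m.group(2).split(",")]
    return all(mod in allowed_modules for mod in modules) and m.group(3) in allowed_types


print(check_pr_title("[veomni] refactor: minor refactoring"))               # True
print(check_pr_title("[BREAKING][fsdp, megatron] feat: dynamic batching"))  # True
print(check_pr_title("refactor: missing module tag"))                       # False
```

Keeping the allow-list in one flat Python list is what makes this PR a one-line change: supporting a new module tag only requires appending to the list.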

`verl/workers/engine/veomni/transformer_impl.py` (73 additions, 46 deletions)

```diff
@@ -31,12 +31,9 @@
 from verl.trainer.config import CheckpointConfig
 from verl.utils import tensordict_utils as tu
 from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
-from verl.utils.device import (
-    get_device_id,
-)
-from verl.utils.fsdp_utils import (
-    fsdp_version,
-)
+from verl.utils.device import get_device_id
+from verl.utils.fsdp_utils import fsdp_version
+from verl.utils.profiler import log_gpu_memory_usage
 from verl.workers.config import HFModelConfig, VeOmniEngineConfig, VeOmniOptimizerConfig
 from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
 
@@ -95,7 +92,6 @@ def __init__(
 
         self.use_remove_padding = self.model_config.use_remove_padding
 
-        # set FSDP offload params
         self._is_offload_param = self.engine_config.param_offload
         self._is_offload_optimizer = self.engine_config.optimizer_offload
         self._is_lora = self.model_config.lora_rank > 0
@@ -121,69 +117,100 @@ def __init__(
 
     def initialize(self):
         """
-        Build the model, optimizer, and learning rate scheduler under FSDP.
+        Build the model, optimizer, and learning rate scheduler under VeOmni.
 
         Applies device, dtype, and precision configurations, including mixed precision.
         Sets up checkpoint manager and FLOPs counter.
         """
+        self._build_model_optimizer()
 
-        self.module = build_foundation_model(
+        self.checkpoint_manager = FSDPCheckpointManager(
+            model=self.module,
+            optimizer=self.optimizer,
+            lr_scheduler=self.lr_scheduler,
+            processing_class=self.model_config.get_processor(),
+            checkpoint_config=self.checkpoint_config,
+        )
+
+        self.to(
+            device="cpu",
+            model=self._is_offload_param,
+            optimizer=self._is_offload_optimizer,
+            grad=self._is_offload_optimizer,
+        )
+
+        log_gpu_memory_usage("After offload model/optimizer/grad during init", logger=logger)
+
+    def _build_optimizer(self, module):
+        optimizer = build_optimizer(
+            module,
+            lr=self.optimizer_config.lr,
+            betas=self.optimizer_config.betas,
+            weight_decay=self.optimizer_config.weight_decay,
+            optimizer_type=self.optimizer_config.optimizer,
+        )
+        get_optimizer_pre_hook = getattr(module, "get_optimizer_pre_hook", None)
+        if get_optimizer_pre_hook is not None:
+            optimizer_pre_hook = get_optimizer_pre_hook(module, module.config, self.engine_config.data_parallel_mode)
+            optimizer.register_step_pre_hook(optimizer_pre_hook)
+
+        return optimizer
+
+    def _build_lr_scheduler(self, optimizer):
+        optim_config = self.optimizer_config
+        lr_scheduler = build_lr_scheduler(
+            optimizer,
+            train_steps=optim_config.total_training_steps,
+            lr=optim_config.lr,
+            lr_min=optim_config.lr_min,
+            lr_decay_style=optim_config.lr_scheduler_type,
+            lr_decay_ratio=optim_config.lr_decay_ratio,
+            lr_warmup_ratio=optim_config.lr_warmup_steps_ratio,
+            lr_start=optim_config.lr_start,
+        )
+
+        return lr_scheduler
+
+    def _build_model_optimizer(self):
+        # Load base model with specified configuration and dtype
+        module = build_foundation_model(
             config_path=self.model_config.hf_config_path,
             weights_path=self.model_config.path,
             torch_dtype="float32" if self.engine_config.mixed_precision else "bfloat16",
             attn_implementation=self.engine_config.attn_implementation,
             moe_implementation=self.engine_config.moe_implementation,
             init_device=self.engine_config.init_device,
         )
+        log_gpu_memory_usage("After load base model", logger=logger)
 
-        module_config = self.module.config
-
-        get_optimizer_pre_hook = getattr(self.module, "get_optimizer_pre_hook", None)
-        self.module = build_parallelize_model(
-            self.module,
+        # Applies parallel strategies to the model.
+        log_gpu_memory_usage("Before parallelize model", logger=logger)
+        module = build_parallelize_model(
+            module,
             init_device=self.engine_config.init_device,
             weights_path=self.model_config.path,
             enable_full_shard=self.engine_config.enable_full_shard,
             enable_mixed_precision=self.engine_config.mixed_precision,
             enable_gradient_checkpointing=self.model_config.enable_gradient_checkpointing,
             enable_fsdp_offload=self.engine_config.enable_fsdp_offload,
-            basic_modules=self.module._no_split_modules + self.engine_config.basic_modules,
+            basic_modules=module._no_split_modules + self.engine_config.basic_modules,
             enable_reentrant=self.engine_config.enable_reentrant,
             enable_forward_prefetch=self.engine_config.forward_prefetch,
         )
+        log_gpu_memory_usage("After parallelize model", logger=logger)
 
-        self.optimizer = build_optimizer(
-            self.module,
-            lr=self.optimizer_config.lr,
-            betas=self.optimizer_config.betas,
-            weight_decay=self.optimizer_config.weight_decay,
-            optimizer_type=self.optimizer_config.optimizer,
-        )
-        if get_optimizer_pre_hook is not None:
-            optimizer_pre_hook = get_optimizer_pre_hook(
-                self.module, module_config, self.engine_config.data_parallel_mode
-            )
-            self.optimizer.register_step_pre_hook(optimizer_pre_hook)
-
-        self.lr_scheduler = build_lr_scheduler(
-            self.optimizer,
-            train_steps=self.optimizer_config.total_training_steps,
-            lr=self.optimizer_config.lr,
-            lr_min=self.optimizer_config.lr_min,
-            lr_decay_style=self.optimizer_config.lr_scheduler_type,
-            lr_decay_ratio=self.optimizer_config.lr_decay_ratio,
-            lr_warmup_ratio=self.optimizer_config.lr_warmup_steps_ratio,
-            lr_start=self.optimizer_config.lr_start,
-        )
-
-        self.checkpoint_manager = FSDPCheckpointManager(
-            model=self.module,
-            optimizer=self.optimizer,
-            lr_scheduler=self.lr_scheduler,
-            processing_class=self.model_config.get_processor(),
-            checkpoint_contents=self.checkpoint_config,
-        )
+        if not self.engine_config.forward_only:
+            # Initialize optimizer with model parameters and config settings
+            optimizer = self._build_optimizer(module)
+            # Create learning rate scheduler with warmup and decay settings
+            lr_scheduler = self._build_lr_scheduler(optimizer)
+        else:
+            optimizer = None
+            lr_scheduler = None
 
+        self.module = module
+        self.optimizer = optimizer
+        self.lr_scheduler = lr_scheduler
         self.model_fwd_context, self.model_bwd_context = build_activation_offloading_context(
             self.model_config.enable_activation_offload,
             self.model_config.enable_gradient_checkpointing,
```
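The `_build_optimizer` helper in this diff uses an optional-protocol pattern: models opt in to a per-step hook by defining `get_optimizer_pre_hook`, which the engine registers via `torch.optim.Optimizer.register_step_pre_hook`. A standalone, torch-free sketch of that pattern is below; `PlainModel`, `MoEModel`, `FakeOptimizer`, and the hook body are all invented for illustration (a real hook might, say, all-reduce expert gradients before the step).

```python
class PlainModel:
    """Model without an optimizer pre-hook; the engine must tolerate this."""


class MoEModel:
    """Hypothetical model that supplies a pre-hook for its optimizer."""

    @staticmethod
    def get_optimizer_pre_hook(module, config, data_parallel_mode):
        def pre_hook(optimizer, args, kwargs):
            # Runs before every optimizer.step(); just counts calls here.
            optimizer.pre_hook_calls += 1

        return pre_hook


class FakeOptimizer:
    """Minimal stand-in for torch.optim.Optimizer's step pre-hook API."""

    def __init__(self):
        self.pre_hook_calls = 0
        self._pre_hooks = []

    def register_step_pre_hook(self, hook):
        self._pre_hooks.append(hook)

    def step(self):
        # Fire registered pre-hooks with the (optimizer, args, kwargs) shape
        # that torch's real hook protocol uses.
        for hook in self._pre_hooks:
            hook(self, (), {})


def build_optimizer_with_hook(module):
    optimizer = FakeOptimizer()
    # Same optional lookup the diff uses: absent attribute means no hook.
    get_hook = getattr(module, "get_optimizer_pre_hook", None)
    if get_hook is not None:
        optimizer.register_step_pre_hook(get_hook(module, None, "fsdp"))
    return optimizer


opt = build_optimizer_with_hook(MoEModel())
opt.step()
print(opt.pre_hook_calls)  # → 1
plain_opt = build_optimizer_with_hook(PlainModel())
plain_opt.step()
print(plain_opt.pre_hook_calls)  # → 0
```

Using `getattr(..., None)` keeps the engine decoupled from any particular model class: hook-aware models and plain models go through the same builder.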
