Skip to content

Commit f9558aa

Browse files
committed
add get_per_tensor_param
1 parent a3f2c16 commit f9558aa

File tree

1 file changed

+90
-29
lines changed

verl/workers/engine/veomni/transformer_impl.py

Lines changed: 90 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import logging
1717
from dataclasses import dataclass, field
18-
from typing import Any, Callable, Sequence
18+
from typing import Any, Callable, Optional, Sequence
1919

2020
import torch
2121
import torch.distributed as dist
@@ -33,6 +33,7 @@
3333
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
3434
from verl.utils.device import get_device_id, get_device_name
3535
from verl.utils.fsdp_utils import fsdp_version
36+
from verl.utils.model import convert_weight_keys
3637
from verl.utils.profiler import log_gpu_memory_usage
3738
from verl.utils.veomni_utils import (
3839
load_veomni_model_to_gpu,
@@ -223,34 +224,6 @@ def _build_model_optimizer(self):
223224
self.engine_config.activation_gpu_limit,
224225
)
225226

226-
def to(self, device: str, model: bool = True, optimizer: bool = True, grad: bool = True):
    """
    Manually move the model and/or optimizer states between CPU and the accelerator.

    Runs regardless of any offload configuration; it exists as explicit, manual control.

    Args:
        device: Target device identifier (the accelerator name or "cpu").
        model: If True, move the model.
        optimizer: If True, move the optimizer states.
        grad: Forwarded unchanged to the parent implementation.
    """
    super(FSDPEngine, self).to(device=device, model=model, optimizer=optimizer, grad=grad)

    accelerator = get_device_name()

    assert device in (accelerator, "cpu")
    if device == "cpu":
        if model:
            offload_veomni_model_to_cpu(self.module)
        if optimizer and self.optimizer is not None:
            offload_veomni_optimizer(self.optimizer)
    elif device == accelerator:
        if model:
            load_veomni_model_to_gpu(self.module)
        if optimizer and self.optimizer is not None:
            load_veomni_optimizer(self.optimizer, device)
    else:
        raise ValueError(f"Invalid device type: {device}")
253-
254227
def optimizer_step(self):
255228
"""
256229
Perform an optimization step using the optimizer.
@@ -347,6 +320,94 @@ def eval_mode(self, **kwargs):
347320
Includes activation offload entry/exit.
348321
"""
349322
return EngineEvalModeCtx(self, **kwargs)
323+
324+
def to(self, device: str, model: bool = True, optimizer: bool = True, grad: bool = True):
    """
    Move model parameters, optimizer states, or both to the specified device.

    Note that this function executes irrespective of offload config. It serves as manual control.

    Args:
        device: Target device identifier (the accelerator name or "cpu").
        model: If True, move the model.
        optimizer: If True, move the optimizer states.
        grad: Forwarded unchanged to the parent implementation.

    Raises:
        ValueError: If ``device`` is neither the accelerator device nor "cpu".
    """
    super(FSDPEngine, self).to(device=device, model=model, optimizer=optimizer, grad=grad)

    device_name = get_device_name()

    # Validate via the explicit raise in the else branch below; the previous
    # `assert device in (...)` was stripped under `python -O` and made that
    # branch unreachable otherwise.
    if device == device_name:
        if model:
            load_veomni_model_to_gpu(self.module)
        if optimizer and self.optimizer is not None:
            load_veomni_optimizer(self.optimizer, device)
    elif device == "cpu":
        if model:
            offload_veomni_model_to_cpu(self.module)
        if optimizer and self.optimizer is not None:
            offload_veomni_optimizer(self.optimizer)
    else:
        raise ValueError(f"Invalid device type: {device}")
351+
352+
def save_checkpoint(
    self,
    local_path: str,
    hdfs_path: Optional[str] = None,
    global_step: int = 0,
    max_ckpt_to_keep: Optional[int] = None,
    **kwargs,
) -> None:
    """
    Save VeOmni checkpoint, handling parameter offload as needed.

    Args:
        local_path: Local directory to write the checkpoint to.
        hdfs_path: Optional remote (HDFS) destination.
        global_step: Training step recorded with the checkpoint.
        max_ckpt_to_keep: Upper bound on retained checkpoints, if any.
        **kwargs: Accepted for interface compatibility; unused here.
    """
    # The checkpoint manager needs parameters on the accelerator; bring them
    # back if offloading (or a prior manual move) left them on CPU.
    current_device = next(self.module.parameters()).device.type
    if self._is_offload_param or current_device == "cpu":
        load_veomni_model_to_gpu(self.module)

    self.checkpoint_manager.save_checkpoint(
        local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep
    )

    # Keep all ranks in sync before (possibly) offloading again.
    torch.distributed.barrier()
    if self._is_offload_param:
        offload_veomni_model_to_cpu(self.module)
374+
375+
def load_checkpoint(
    self, local_path: str, hdfs_path: Optional[str] = None, del_local_after_load: bool = True, **kwargs
) -> None:
    """
    Load VeOmni checkpoint, restoring parameters and optimizer state.

    Args:
        local_path: Local directory to read the checkpoint from.
        hdfs_path: Optional remote (HDFS) source.
        del_local_after_load: Whether to delete the local copy after loading.
            (Annotation fixed from ``int`` to ``bool`` to match the default.)
        **kwargs: Accepted for interface compatibility; unused here.
    """
    # Loading writes into parameters, so they must be resident on the accelerator.
    if self._is_offload_param:
        load_veomni_model_to_gpu(self.module)

    self.checkpoint_manager.load_checkpoint(
        local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load
    )

    # Synchronize ranks before restoring the configured offload state.
    torch.distributed.barrier()
    if self._is_offload_param:
        offload_veomni_model_to_cpu(self.module)

    if self._is_offload_optimizer:
        offload_veomni_optimizer(self.optimizer)
394+
395+
def get_per_tensor_param(self, **kwargs):
    """
    Build a lazy stream of (name, tensor) pairs over the model's state dict.

    DTensor entries are moved to the current device and materialized via
    ``full_tensor()`` only when the stream is consumed; other entries are
    yielded as-is.

    Returns:
        A tuple ``(generator, None)``; the second element is a placeholder
        for LoRA data (not supported yet — see TODO).
    """
    load_veomni_model_to_gpu(self.module)

    state = self.module.state_dict()
    state = convert_weight_keys(state, getattr(self.module, "_fsdp_wrapped_module", self.module))

    # Offload eagerly; the state-dict references keep the tensors reachable
    # for the lazy generator below.
    if self._is_offload_param:
        offload_veomni_model_to_cpu(self.module)

    target = get_device_id()

    def _iter_full_tensors():
        for key, value in state.items():
            if isinstance(value, DTensor):
                yield key, value.to(target, non_blocking=True).full_tensor()
            else:
                yield key, value

    # TODO: support veomni LoRA
    return _iter_full_tensors(), None
350411

351412

352413
class EngineEvalModeCtx(BaseEngineCtx):

0 commit comments

Comments
 (0)