|
15 | 15 |
|
16 | 16 | import logging |
17 | 17 | from dataclasses import dataclass, field |
18 | | -from typing import Any, Callable, Sequence |
| 18 | +from typing import Any, Callable, Optional, Sequence |
19 | 19 |
|
20 | 20 | import torch |
21 | 21 | import torch.distributed as dist |
|
33 | 33 | from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager |
34 | 34 | from verl.utils.device import get_device_id, get_device_name |
35 | 35 | from verl.utils.fsdp_utils import fsdp_version |
| 36 | +from verl.utils.model import convert_weight_keys |
36 | 37 | from verl.utils.profiler import log_gpu_memory_usage |
37 | 38 | from verl.utils.veomni_utils import ( |
38 | 39 | load_veomni_model_to_gpu, |
@@ -223,34 +224,6 @@ def _build_model_optimizer(self): |
223 | 224 | self.engine_config.activation_gpu_limit, |
224 | 225 | ) |
225 | 226 |
|
226 | | - def to(self, device: str, model: bool = True, optimizer: bool = True, grad: bool = True): |
227 | | - """ |
228 | | - Move model parameters, optimizer states, or both to the specified device. |
229 | | - Note that this function executes irrespective of offload config. It serves as manual control. |
230 | | -
|
231 | | - Args: |
232 | | - device: Target device identifier. |
233 | | - model: If True, move the model. |
234 | | - optimizer: If True, move the optimizer states. |
235 | | - """ |
236 | | - super(FSDPEngine, self).to(device=device, model=model, optimizer=optimizer, grad=grad) |
237 | | - |
238 | | - device_name = get_device_name() |
239 | | - |
240 | | - assert device in (device_name, "cpu") |
241 | | - if device == device_name: |
242 | | - if model: |
243 | | - load_veomni_model_to_gpu(self.module) |
244 | | - if optimizer and self.optimizer is not None: |
245 | | - load_veomni_optimizer(self.optimizer, device) |
246 | | - elif device == "cpu": |
247 | | - if model: |
248 | | - offload_veomni_model_to_cpu(self.module) |
249 | | - if optimizer and self.optimizer is not None: |
250 | | - offload_veomni_optimizer(self.optimizer) |
251 | | - else: |
252 | | - raise ValueError(f"Invalid device type: {device}") |
253 | | - |
254 | 227 | def optimizer_step(self): |
255 | 228 | """ |
256 | 229 | Perform an optimization step using the optimizer. |
@@ -347,6 +320,94 @@ def eval_mode(self, **kwargs): |
347 | 320 | Includes activation offload entry/exit. |
348 | 321 | """ |
349 | 322 | return EngineEvalModeCtx(self, **kwargs) |
| 323 | + |
| 324 | + def to(self, device: str, model: bool = True, optimizer: bool = True, grad: bool = True): |
| 325 | + """ |
| 326 | + Move model parameters, optimizer states, or both to the specified device. |
| 327 | + Note that this function executes irrespective of offload config. It serves as manual control. |
| 328 | +
|
| 329 | + Args: |
| 330 | + device: Target device identifier. |
| 331 | + model: If True, move the model. |
| 332 | + optimizer: If True, move the optimizer states. |
| 333 | + """ |
| 334 | + super(FSDPEngine, self).to(device=device, model=model, optimizer=optimizer, grad=grad) |
| 335 | + |
| 336 | + device_name = get_device_name() |
| 337 | + |
| 338 | + assert device in (device_name, "cpu") |
| 339 | + if device == device_name: |
| 340 | + if model: |
| 341 | + load_veomni_model_to_gpu(self.module) |
| 342 | + if optimizer and self.optimizer is not None: |
| 343 | + load_veomni_optimizer(self.optimizer, device) |
| 344 | + elif device == "cpu": |
| 345 | + if model: |
| 346 | + offload_veomni_model_to_cpu(self.module) |
| 347 | + if optimizer and self.optimizer is not None: |
| 348 | + offload_veomni_optimizer(self.optimizer) |
| 349 | + else: |
| 350 | + raise ValueError(f"Invalid device type: {device}") |
| 351 | + |
| 352 | + def save_checkpoint( |
| 353 | + self, |
| 354 | + local_path: str, |
| 355 | + hdfs_path: Optional[str] = None, |
| 356 | + global_step: int = 0, |
| 357 | + max_ckpt_to_keep: Optional[int] = None, |
| 358 | + **kwargs, |
| 359 | + ) -> None: |
| 360 | + """ |
| 361 | + Save VeOmni checkpoint, handling parameter offload as needed. |
| 362 | + """ |
| 363 | + origin_module_device = next(self.module.parameters()).device.type |
| 364 | + if self._is_offload_param or origin_module_device == "cpu": |
| 365 | + load_veomni_model_to_gpu(self.module) |
| 366 | + |
| 367 | + self.checkpoint_manager.save_checkpoint( |
| 368 | + local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep |
| 369 | + ) |
| 370 | + |
| 371 | + torch.distributed.barrier() |
| 372 | + if self._is_offload_param: |
| 373 | + offload_veomni_model_to_cpu(self.module) |
| 374 | + |
| 375 | + def load_checkpoint( |
| 376 | + self, local_path: str, hdfs_path: Optional[str] = None, del_local_after_load: int = True, **kwargs |
| 377 | + ) -> None: |
| 378 | + """ |
| 379 | + Load VeOmni checkpoint, restoring parameters and optimizer state. |
| 380 | + """ |
| 381 | + if self._is_offload_param: |
| 382 | + load_veomni_model_to_gpu(self.module) |
| 383 | + |
| 384 | + self.checkpoint_manager.load_checkpoint( |
| 385 | + local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load |
| 386 | + ) |
| 387 | + |
| 388 | + torch.distributed.barrier() |
| 389 | + if self._is_offload_param: |
| 390 | + offload_veomni_model_to_cpu(self.module) |
| 391 | + |
| 392 | + if self._is_offload_optimizer: |
| 393 | + offload_veomni_optimizer(self.optimizer) |
| 394 | + |
    def get_per_tensor_param(self, **kwargs):
        """
        Return model parameters as a lazy per-tensor iterator.

        Returns:
            Tuple of (generator yielding ``(name, tensor)`` pairs, ``None``). The
            second element is a placeholder for LoRA weights (see TODO below).
        """
        # Bring parameters onto the accelerator before materializing the state dict.
        load_veomni_model_to_gpu(self.module)

        params = self.module.state_dict()
        # Presumably remaps state-dict keys to match the unwrapped module's naming;
        # falls back to self.module when no FSDP wrapper attribute exists.
        params = convert_weight_keys(params, getattr(self.module, "_fsdp_wrapped_module", self.module))

        if self._is_offload_param:
            offload_veomni_model_to_cpu(self.module)

        device = get_device_id()
        # Lazy generator: DTensor shards are moved to `device` and materialized as
        # full tensors; plain tensors pass through unchanged.
        # NOTE(review): the generator is consumed after the optional CPU offload
        # above — assumes the captured state_dict tensors stay valid; confirm.
        per_tensor_param = (
            (name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param)
            for name, param in params.items()
        )
        # TODO: support veomni LoRA
        return per_tensor_param, None
350 | 411 |
|
351 | 412 |
|
352 | 413 | class EngineEvalModeCtx(BaseEngineCtx): |
|
0 commit comments