Skip to content

Commit 27a9ede

Browse files
authored
[ckpt] fix: properly handle optimizer offloading for HybridDeviceOptimizer (#4870)
1 parent 1fa9131 commit 27a9ede

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

verl/utils/megatron_utils.py

Lines changed: 18 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -557,12 +557,24 @@ def _iter_opts(opt):
557557
offload_megatron_copy_params(_opt)
558558
## worker may hold zero parameter when enabling custom pipeline layout
559559
if _opt.optimizer is not None:
560-
opt_state_dict_values = _opt.optimizer.state.values()
561-
for v in opt_state_dict_values:
562-
if "exp_avg" in v:
563-
v["exp_avg"] = v["exp_avg"].to("cpu", non_blocking=True)
564-
if "exp_avg_sq" in v:
565-
v["exp_avg_sq"] = v["exp_avg_sq"].to("cpu", non_blocking=True)
560+
# HybridDeviceOptimizer: offload all sub-optimizer states to CPU
561+
# TODO: this should be a method in Megatron-LM's HybridDeviceOptimizer
562+
hdo = _opt.optimizer
563+
if all(hasattr(hdo, attr) for attr in ("sub_optimizers", "inner_param_to_orig_param", "state")):
564+
for optimizer in hdo.sub_optimizers:
565+
for param, state in optimizer.state.items():
566+
for k, v in state.items():
567+
if not isinstance(v, torch.Tensor):
568+
continue
569+
orig_param = hdo.inner_param_to_orig_param.get(param, param)
570+
hdo.state[orig_param][k] = state[k] = v.to("cpu")
571+
else:
572+
opt_state_dict_values = _opt.optimizer.state.values()
573+
for v in opt_state_dict_values:
574+
if "exp_avg" in v:
575+
v["exp_avg"] = v["exp_avg"].to("cpu", non_blocking=True)
576+
if "exp_avg_sq" in v:
577+
v["exp_avg_sq"] = v["exp_avg_sq"].to("cpu", non_blocking=True)
566578
gc.collect()
567579
get_torch_device().empty_cache()
568580

0 commit comments

Comments (0)