@@ -557,12 +557,24 @@ def _iter_opts(opt):
         offload_megatron_copy_params(_opt)
         # worker may hold zero parameter when enabling custom pipeline layout
         if _opt.optimizer is not None:
-            opt_state_dict_values = _opt.optimizer.state.values()
-            for v in opt_state_dict_values:
-                if "exp_avg" in v:
-                    v["exp_avg"] = v["exp_avg"].to("cpu", non_blocking=True)
-                if "exp_avg_sq" in v:
-                    v["exp_avg_sq"] = v["exp_avg_sq"].to("cpu", non_blocking=True)
+            # HybridDeviceOptimizer: offload all sub-optimizer states to CPU
+            # TODO: this should be a method in Megatron-LM's HybridDeviceOptimizer
+            hdo = _opt.optimizer
+            if all(hasattr(hdo, attr) for attr in ("sub_optimizers", "inner_param_to_orig_param", "state")):
+                for optimizer in hdo.sub_optimizers:
+                    for param, state in optimizer.state.items():
+                        for k, v in state.items():
+                            if not isinstance(v, torch.Tensor):
+                                continue
+                            orig_param = hdo.inner_param_to_orig_param.get(param, param)
+                            hdo.state[orig_param][k] = state[k] = v.to("cpu")
+            else:
+                opt_state_dict_values = _opt.optimizer.state.values()
+                for v in opt_state_dict_values:
+                    if "exp_avg" in v:
+                        v["exp_avg"] = v["exp_avg"].to("cpu", non_blocking=True)
+                    if "exp_avg_sq" in v:
+                        v["exp_avg_sq"] = v["exp_avg_sq"].to("cpu", non_blocking=True)
         gc.collect()
         get_torch_device().empty_cache()

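To illustrate the new branch in isolation: a minimal, self-contained sketch of the same offload loop. The HybridDeviceOptimizer here is a hypothetical stand-in exposing only the attributes the patch checks for (sub_optimizers, inner_param_to_orig_param, state); the real class lives in Megatron-LM and may differ.

import torch


class FakeHybridDeviceOptimizer:
    """Stand-in with the duck-typed interface the patch relies on (assumption,
    not the real Megatron-LM class)."""

    def __init__(self, params):
        # one sub-optimizer holding Adam-style state for each param
        self.sub_optimizers = [torch.optim.Adam(params)]
        # identity mapping here: inner (per-sub-optimizer) param -> original param
        self.inner_param_to_orig_param = {p: p for p in params}
        # wrapper-level view of the state, keyed by the original param
        self.state = {p: {} for p in params}


def offload_hybrid_optimizer_state(hdo):
    """Move every tensor in every sub-optimizer's state to CPU, keeping the
    wrapper's own state dict pointing at the same (now CPU) tensors."""
    for optimizer in hdo.sub_optimizers:
        for param, state in optimizer.state.items():
            for k, v in state.items():
                if not isinstance(v, torch.Tensor):
                    continue
                orig_param = hdo.inner_param_to_orig_param.get(param, param)
                hdo.state[orig_param][k] = state[k] = v.to("cpu")


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    params = [torch.nn.Parameter(torch.randn(4, 4, device=device))]
    hdo = FakeHybridDeviceOptimizer(params)
    # take one step so exp_avg / exp_avg_sq exist in the sub-optimizer state
    params[0].grad = torch.randn_like(params[0])
    hdo.sub_optimizers[0].step()
    offload_hybrid_optimizer_state(hdo)
    print({k: v.device for k, v in hdo.sub_optimizers[0].state[params[0]].items()
           if isinstance(v, torch.Tensor)})  # every state tensor now on cpu

The design point mirrored from the diff is that the CPU copy is written back into both the sub-optimizer's state and the wrapper's state dict (keyed by the original parameter), so the two views stay consistent after offloading.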