@@ -371,6 +371,8 @@ def __init__(
371371 "Setting FSDP option to megatron"
372372 )
373373 fsdp = 'megatron'
374+ if self .save_ckpt_format != "fsdp_dtensor" :
375+ raise NotImplementedError (f"FSDP checkpointing is not supported with { self .save_ckpt_format } ." )
374376
375377 if fsdp == "pytorch" :
376378 raise NotImplementedError ("PyTorch FSDP2 is not supported with MegatronParallel." )
@@ -936,28 +938,24 @@ def _get_fsdp_dtensor_state_dict(
936938 model_key = "model" ,
937939 optimizer_key = "optimizer_states" ,
938940 ):
941+ from megatron .core .distributed .fsdp .src .megatron_fsdp .uneven_dtensor import (
942+ preprocess_state_dict_for_uneven_dtensor ,
943+ )
939944 from megatron .core .transformer .fsdp_dtensor_checkpoint import (
940945 handle_fp8_extra_state_case ,
941946 handle_swiglu_in_state_dict ,
942947 )
943- from megatron .core .distributed .fsdp .src .megatron_fsdp .uneven_dtensor import (
944- preprocess_state_dict_for_uneven_dtensor ,
945- )
946948
947949 state_dict = raw_state_dict .copy ()
948950 handle_fp8_extra_state_case (state_dict [model_key ])
949951 module = self .model [0 ].module
950- if torch .distributed .get_rank () == 0 :
951- print (self .model , module )
952952 if getattr (module .config , "gated_linear_unit" , False ):
953953 model_state_dict = state_dict [model_key ].copy ()
954954 if optimizer_key in state_dict :
955955 optimizer_state_dict = state_dict [optimizer_key ].copy ()
956956 else :
957957 optimizer_state_dict = {}
958- handle_swiglu_in_state_dict (
959- module .module , model_state_dict , optimizer_state_dict
960- )
958+ handle_swiglu_in_state_dict (module .module , model_state_dict , optimizer_state_dict )
961959 state_dict [model_key ] = model_state_dict
962960 if optimizer_key in state_dict :
963961 state_dict [optimizer_key ] = optimizer_state_dict
@@ -1060,9 +1058,7 @@ def _load_fsdp_dtensor_checkpoint(self, path, sharded_state_dict, strict):
10601058 checkpoint_id = path ,
10611059 planner = planner ,
10621060 )
1063- sharded_state_dict .update (
1064- self ._load_fsdp_dtensor_common_state (ckpt_dir = path )
1065- )
1061+ sharded_state_dict .update (self ._load_fsdp_dtensor_common_state (ckpt_dir = path ))
10661062 if "loops" in sharded_state_dict :
10671063 sharded_state_dict ["fit_loop" ] = sharded_state_dict ["loops" ]["fit_loop" ]
10681064
0 commit comments