 from torchtune.modules.peft import (
     AdapterModule,
     get_adapter_params,
-    get_adapter_state_dict,
     get_lora_module_names,
-    get_merged_lora_ckpt,
     set_trainable_params,
     validate_missing_and_unexpected_for_lora,
 )
     PROFILER_KEY,
     VALID_BACKENDS_FOR_MEMORY_STATS,
 )
+from torchtune.training.checkpointing._checkpoint_client import (
+    CheckpointClient,
+    TrainingProgress,
+)
 from tqdm import tqdm
 
 
@@ -168,6 +170,9 @@ def __init__(self, cfg: DictConfig) -> None:
             )
             self._log_peak_memory_stats = False
 
+        self._enable_async_checkpointing = cfg.get("enable_async_checkpointing", False)
+        self._checkpoint_client = CheckpointClient(cfg)
+
         # These attributes constitute the recipe state and are updated by ``load_checkpoint``
         # when ``resume_from_checkpoint`` is ``True``
         self.seed = training.set_seed(
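
Note: `CheckpointClient` becomes the single entry point for checkpoint I/O in this recipe, replacing the hand-rolled `load_checkpoint`/`save_checkpoint` logic removed below. `enable_async_checkpointing` defaults to False and opts intermediate saves into the DistributedCheckpointer-based async path described later in the diff.
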
@@ -215,31 +220,6 @@ def __init__(self, cfg: DictConfig) -> None:
                 "Enabling activation offloading should reduce memory further.",
             )
 
-    def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
-        """
-        Extract the checkpoint state from file and validate. This includes the
-        base model weights. If resume_from_checkpoint is True, this also includes
-        the adapter weights and recipe state
-        """
-        self._checkpointer = config.instantiate(
-            cfg_checkpointer,
-            should_load_recipe_state=self._resume_from_checkpoint,
-        )
-        checkpoint_dict = self._checkpointer.load_checkpoint()
-
-        # When resuming from checkpoint for LoRA, the recipe expects the adapter weights
-        # and recipe state to be present. The keys should match up with what ``save_checkpoint``
-        # used to create these intermediate checkpoints
-        if self._resume_from_checkpoint:
-            if training.ADAPTER_KEY not in checkpoint_dict:
-                raise ValueError(
-                    "Adapter weights not found. Please ensure a valid adapter checkpoint is provided."
-                )
-            # _update_recipe_state will throw an exception if the recipe state is not corrctly loaded
-            # no need to check here
-            self._update_recipe_state(checkpoint_dict)
-        return checkpoint_dict
-
     def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
         """
         Updates the recipe state from checkpoint.
@@ -299,7 +279,8 @@ def setup(self, cfg: DictConfig) -> None:
                 "For Llama4 training, you should set save_adapter_weights_only to True."
             )
 
-        checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
+        checkpoint_dict = self._checkpoint_client.load_base_checkpoint()
+
         self._compile = cfg.get("compile", False)
 
         self._model = self._setup_model(
@@ -312,7 +293,7 @@ def setup(self, cfg: DictConfig) -> None:
             base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
             lora_weights_state_dict=(
                 checkpoint_dict[training.ADAPTER_KEY]
-                if self._resume_from_checkpoint
+                if training.ADAPTER_KEY in checkpoint_dict
                 else None
             ),
         )
@@ -322,11 +303,38 @@ def setup(self, cfg: DictConfig) -> None:
             cfg_optimizer=cfg.optimizer,
             opt_state_dict=(
                 checkpoint_dict[training.OPT_KEY]
-                if self._resume_from_checkpoint
+                if training.OPT_KEY in checkpoint_dict
                 else None
             ),
         )
 
+        if self._resume_from_checkpoint:
+            # If async checkpointing is enabled, intermediate checkpoints are saved asynchronously
+            # using the DistributedCheckpointer.
+            # Therefore the recipe needs to load the distributed checkpoint to restore the training
+            # progress.
+            if self._enable_async_checkpointing:
+                try:
+                    checkpoint_dict = (
+                        self._checkpoint_client.load_distributed_checkpoint(
+                            self._model,
+                            self._optimizer,
+                            self._adapter_config,
+                        )
+                    )
+                except Exception as e:
+                    log.warning(
+                        f"Failed to load distributed checkpoint: {e}. Training will start from the base checkpoint."
+                    )
+
+            if training.ADAPTER_KEY not in checkpoint_dict:
+                raise ValueError(
+                    "Adapter weights not found. Please ensure a valid adapter checkpoint is provided."
+                )
+
+            # Update the recipe state from the checkpoint state dict.
+            self._update_recipe_state(checkpoint_dict)
+
         # initialize loss
         self._loss_fn = config.instantiate(cfg.loss)
         if isinstance(self._loss_fn, SFTLoss):
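
Note: optional state is now detected from the contents of `checkpoint_dict` (whether `training.ADAPTER_KEY` / `training.OPT_KEY` are present) rather than from the `resume_from_checkpoint` flag. On resume with async checkpointing enabled, the recipe hands the model, optimizer, and adapter config to `load_distributed_checkpoint`, falling back to the base checkpoint with a warning if that load fails; the adapter-key check and `_update_recipe_state` then operate on whichever checkpoint dict resulted.
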
@@ -345,11 +353,6 @@ def setup(self, cfg: DictConfig) -> None:
             shuffle=cfg.shuffle,
             batch_size=cfg.batch_size,
             collate_fn=collate_name,
-            dataloader_state_dict=(
-                checkpoint_dict[training.DATALOADER_KEY]
-                if self._resume_from_checkpoint
-                else None
-            ),
         )
 
         # Setup validation dataloader if validation dataset is provided
@@ -450,6 +453,16 @@ def _setup_model(
         self._lora_attn_modules = list(cfg_model.lora_attn_modules)
         self._apply_lora_to_mlp = cfg_model.apply_lora_to_mlp
         self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False)
+        self._adapter_config = {
+            "r": self._lora_rank,
+            "lora_alpha": self._lora_alpha,
+            "target_modules": get_lora_module_names(
+                self._lora_attn_modules,
+                self._apply_lora_to_mlp,
+                self._apply_lora_to_output,
+            ),
+            "peft_type": "LORA",
+        }
 
         utils.log_rank_zero(
             self._logger,
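
Note: the PEFT-style adapter config that `save_checkpoint` used to rebuild on every save (see the removed block further down) is now built once here and reused both when saving and when calling `load_distributed_checkpoint` on resume. As a rough, hypothetical illustration of its shape — the concrete values depend on the LoRA settings in the recipe config — the dict maps onto the fields commonly found in an `adapter_config.json`:

    import json

    # Hypothetical values for illustration only; the real ones come from cfg_model.
    adapter_config = {
        "r": 8,
        "lora_alpha": 16,
        "target_modules": ["q_proj", "v_proj", "output_proj", "w1", "w2", "w3"],
        "peft_type": "LORA",
    }

    # A checkpointer could persist this alongside the adapter weights, e.g.:
    with open("adapter_config.json", "w") as f:
        json.dump(adapter_config, f, indent=2)
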
@@ -585,7 +598,6 @@ def _setup_data(
         shuffle: bool,
         batch_size: int,
         collate_fn: str,
-        dataloader_state_dict: Optional[Dict[str, Any]] = None,
     ) -> StatefulDataLoader:
         """
         All data related setup happens here. This recipe currently supports only
@@ -637,115 +649,21 @@ def save_checkpoint(
         self,
         epoch: int,
     ) -> None:
-        """
-        Checkpoint the state of the recipe. The constructed checkpoint state dict
-        contains the following information:
-        - Merged weights with key MODEL_KEY
-        - Adapter weights with key ADAPTER_KEY
-        - Relevant recipe state if training is not complete
-        - If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights
-
-        Checkpointer will save the merged weights, adapter weights and recipe state in
-        different checkpoint files. To correctly resume from training, the adapter weights
-        and recipe state must be provided along with the base model weights.
-        """
-        # final dict passed onto the checkpointer
-        checkpoint_dict = {}
-
-        intermediate_checkpoint = epoch + 1 < self.total_epochs
-
-        utils.log_rank_zero(
-            self._logger,
-            "Saving checkpoint. This may take some time. Retrieving full model state dict...",
-        )
-        start = time.perf_counter()
-
-        # To prevent GPU memory from spiking during checkpoint save,
-        # we consolidate the full model and optim state dicts on CPU for rank 0
-        cpu_state_dict = training.gather_cpu_state_dict(
-            self._model,
-            self._is_rank_zero,
-            device=self._device,
-            adapter_weights_only=self._save_adapter_weights_only,
-        )
-        utils.log_rank_zero(
-            self._logger,
-            f"Getting full model state dict took {time.perf_counter() - start:.2f} secs",
+        self._checkpoint_client.save_checkpoint(
+            model=self._model,
+            optimizer=self._optimizer,
+            training_progress=TrainingProgress(
+                seed=self.seed,
+                epochs_run=self.epochs_run,
+                total_epochs=self.total_epochs,
+                max_steps_per_epoch=self.max_steps_per_epoch,
+                dataloader_state_dict=self._dataloader.state_dict(),
+            ),
+            epoch=epoch,
+            adapter_config=self._adapter_config.copy(),
+            adapter_only=self._save_adapter_weights_only,
         )
 
-        if intermediate_checkpoint:
-            utils.log_rank_zero(self._logger, "Retrieving optimizer state dict...")
-            opt_state_dict = training.get_full_optimizer_state_dict(
-                self._model,
-                self._optimizer,
-                self._is_rank_zero,
-                device=self._device,
-            )
-            utils.log_rank_zero(
-                self._logger,
-                f"Getting optimizer state dict took {time.perf_counter() - start:.2f} secs",
-            )
-        else:
-            opt_state_dict = None
-
-        # Now that we have the model and opt state dict, create the actual checkpoint dict
-        # to be sent to the checkpointer and ultimately written to file
-        if self._is_rank_zero:
-            start = time.perf_counter()
-
-            if self._save_adapter_weights_only:
-                adapter_state_dict = cpu_state_dict
-            else:
-                # Filter out the adapter keys and weights from the model state dict. These will
-                # be saved separately
-                adapter_state_dict = get_adapter_state_dict(cpu_state_dict)
-
-                # merge the adapter weights and base weights to create the model checkpoint
-                merged_state_dict = get_merged_lora_ckpt(
-                    cpu_state_dict,
-                    rank=self._lora_rank,
-                    alpha=self._lora_alpha,
-                )
-                checkpoint_dict.update({training.MODEL_KEY: merged_state_dict})
-            checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict})
-
-            # if training is in-progress, checkpoint the optimizer state and recipe state
-            # as well.
-            if intermediate_checkpoint:
-                checkpoint_dict.update(
-                    {
-                        training.OPT_KEY: opt_state_dict,
-                        training.SEED_KEY: self.seed,
-                        training.EPOCHS_KEY: self.epochs_run,
-                        training.TOTAL_EPOCHS_KEY: self.total_epochs,
-                        training.MAX_STEPS_KEY: self.max_steps_per_epoch,
-                        training.DATALOADER_KEY: self._dataloader.state_dict(),
-                    }
-                )
-
-            adapter_config = {
-                "r": self._lora_rank,
-                "lora_alpha": self._lora_alpha,
-                "target_modules": get_lora_module_names(
-                    self._lora_attn_modules,
-                    self._apply_lora_to_mlp,
-                    self._apply_lora_to_output,
-                ),
-                "peft_type": "LORA",
-            }
-            checkpoint_dict.update({training.ADAPTER_CONFIG: adapter_config})
-            self._checkpointer.save_checkpoint(
-                checkpoint_dict,
-                epoch=epoch,
-                intermediate_checkpoint=intermediate_checkpoint,
-                adapter_only=self._save_adapter_weights_only,
-            )
-            self._logger.info(
-                f"Saving checkpoint took {time.perf_counter() - start:.2f} secs"
-            )
-
-        torch.distributed.barrier()
-
     def train(self) -> None:
         """
         The core training loop.
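
Note: with the hand-rolled gather/merge/save logic gone, everything the recipe needs to restart — seed, epoch counters, max steps, and the dataloader state that previously travelled under `training.DATALOADER_KEY` — is packed into `TrainingProgress` and handed to the client. A minimal sketch of how a resumed run might be configured, assuming the config keys used by the recipe above (the YAML path is hypothetical):

    from omegaconf import OmegaConf

    # Hypothetical config file; the keys match the ones the recipe reads above.
    cfg = OmegaConf.load("my_lora_distributed_config.yaml")
    cfg.resume_from_checkpoint = True          # restore recipe state on startup
    cfg.enable_async_checkpointing = True      # intermediate saves went through the DistributedCheckpointer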