Commit a6db644

Support async checkpointing with DCP for Lora finetune recipe (meta-pytorch#2705)
1 parent aa63d17 commit a6db644

File tree

2 files changed (+162 / -143 lines)
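
At a high level, this commit moves checkpoint handling in the distributed LoRA recipe behind the CheckpointClient helper, which can also save intermediate checkpoints asynchronously through DCP (torch.distributed.checkpoint) when enable_async_checkpointing is set. The sketch below strings together only the client calls that appear in the diff; checkpoint_flow itself and its bare parameters are hypothetical stand-ins for the recipe's attributes, not a torchtune API.

# Rough sketch of the flow this commit introduces. CheckpointClient / TrainingProgress
# and the keyword arguments come from the diff below; everything else is a placeholder.
from torchtune.training.checkpointing._checkpoint_client import (
    CheckpointClient,
    TrainingProgress,
)


def checkpoint_flow(
    client: CheckpointClient,
    model,
    optimizer,
    adapter_config: dict,
    progress: TrainingProgress,
    epoch: int,
    resume: bool,
    async_checkpointing: bool,
    adapter_only: bool,
):
    # Base model weights (plus adapter/optimizer state, if the checkpointer saved them)
    # always come from the configured checkpointer.
    checkpoint_dict = client.load_base_checkpoint()

    # With async checkpointing, intermediate state is written by the DistributedCheckpointer,
    # so a resume reloads it through DCP instead of from the base checkpoint.
    if resume and async_checkpointing:
        checkpoint_dict = client.load_distributed_checkpoint(model, optimizer, adapter_config)

    # Saving collapses to a single call; recipe state travels inside TrainingProgress.
    client.save_checkpoint(
        model=model,
        optimizer=optimizer,
        training_progress=progress,
        epoch=epoch,
        adapter_config=adapter_config.copy(),
        adapter_only=adapter_only,
    )
    return checkpoint_dict
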

recipes/lora_finetune_distributed.py

Lines changed: 61 additions & 143 deletions
@@ -29,9 +29,7 @@
 from torchtune.modules.peft import (
     AdapterModule,
     get_adapter_params,
-    get_adapter_state_dict,
     get_lora_module_names,
-    get_merged_lora_ckpt,
     set_trainable_params,
     validate_missing_and_unexpected_for_lora,
 )
@@ -41,6 +39,10 @@
     PROFILER_KEY,
     VALID_BACKENDS_FOR_MEMORY_STATS,
 )
+from torchtune.training.checkpointing._checkpoint_client import (
+    CheckpointClient,
+    TrainingProgress,
+)
 from tqdm import tqdm
 
 
@@ -168,6 +170,9 @@ def __init__(self, cfg: DictConfig) -> None:
         )
         self._log_peak_memory_stats = False
 
+        self._enable_async_checkpointing = cfg.get("enable_async_checkpointing", False)
+        self._checkpoint_client = CheckpointClient(cfg)
+
         # These attributes constitute the recipe state and are updated by ``load_checkpoint``
         # when ``resume_from_checkpoint`` is ``True``
         self.seed = training.set_seed(
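
Both new attributes are driven entirely by config. Below is a minimal sketch of the two flags that activate the new path, assuming nothing beyond an OmegaConf DictConfig (which is what the recipe's cfg is); the rest of the recipe config is omitted.

# Minimal sketch: only the flag names come from this diff; everything else
# (model, checkpointer, tokenizer, ...) is left out.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "resume_from_checkpoint": True,      # restore recipe state on startup
        "enable_async_checkpointing": True,  # save/load intermediate state via DCP
    }
)

# The same lookup the recipe performs in __init__; defaults to False when absent.
enable_async_checkpointing = cfg.get("enable_async_checkpointing", False)
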
@@ -215,31 +220,6 @@ def __init__(self, cfg: DictConfig) -> None:
             "Enabling activation offloading should reduce memory further.",
         )
 
-    def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
-        """
-        Extract the checkpoint state from file and validate. This includes the
-        base model weights. If resume_from_checkpoint is True, this also includes
-        the adapter weights and recipe state
-        """
-        self._checkpointer = config.instantiate(
-            cfg_checkpointer,
-            should_load_recipe_state=self._resume_from_checkpoint,
-        )
-        checkpoint_dict = self._checkpointer.load_checkpoint()
-
-        # When resuming from checkpoint for LoRA, the recipe expects the adapter weights
-        # and recipe state to be present. The keys should match up with what ``save_checkpoint``
-        # used to create these intermediate checkpoints
-        if self._resume_from_checkpoint:
-            if training.ADAPTER_KEY not in checkpoint_dict:
-                raise ValueError(
-                    "Adapter weights not found. Please ensure a valid adapter checkpoint is provided."
-                )
-            # _update_recipe_state will throw an exception if the recipe state is not corrctly loaded
-            # no need to check here
-            self._update_recipe_state(checkpoint_dict)
-        return checkpoint_dict
-
     def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
         """
         Updates the recipe state from checkpoint.
@@ -299,7 +279,8 @@ def setup(self, cfg: DictConfig) -> None:
                 "For Llama4 training, you should set save_adapter_weights_only to True."
             )
 
-        checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
+        checkpoint_dict = self._checkpoint_client.load_base_checkpoint()
+
         self._compile = cfg.get("compile", False)
 
         self._model = self._setup_model(
@@ -312,7 +293,7 @@ def setup(self, cfg: DictConfig) -> None:
            base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
            lora_weights_state_dict=(
                checkpoint_dict[training.ADAPTER_KEY]
-               if self._resume_from_checkpoint
+               if training.ADAPTER_KEY in checkpoint_dict
                else None
            ),
        )
@@ -322,11 +303,38 @@ def setup(self, cfg: DictConfig) -> None:
             cfg_optimizer=cfg.optimizer,
             opt_state_dict=(
                 checkpoint_dict[training.OPT_KEY]
-                if self._resume_from_checkpoint
+                if training.OPT_KEY in checkpoint_dict
                 else None
             ),
         )
 
+        if self._resume_from_checkpoint:
+            # If async checkpointing is enabled, intermediate checkpoints are saved asynchronously
+            # using the DistributedCheckpointer.
+            # Therefore the recipe needs to load the distributed checkpoint to restore the training
+            # progress.
+            if self._enable_async_checkpointing:
+                try:
+                    checkpoint_dict = (
+                        self._checkpoint_client.load_distributed_checkpoint(
+                            self._model,
+                            self._optimizer,
+                            self._adapter_config,
+                        )
+                    )
+                except Exception as e:
+                    log.warning(
+                        f"Failed to load distributed checkpoint: {e}. Training will start from the base checkpoint."
+                    )
+
+            if training.ADAPTER_KEY not in checkpoint_dict:
+                raise ValueError(
+                    "Adapter weights not found. Please ensure a valid adapter checkpoint is provided."
+                )
+
+            # Update the recipe state from the checkpoint state dict.
+            self._update_recipe_state(checkpoint_dict)
+
         # initialize loss
         self._loss_fn = config.instantiate(cfg.loss)
         if isinstance(self._loss_fn, SFTLoss):
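
Note the related guard changes above: the adapter and optimizer states are now keyed on what the loaded checkpoint actually contains rather than on the resume flag, because with async checkpointing those states live in the DCP checkpoint and only arrive via load_distributed_checkpoint. A purely illustrative sketch, with plain strings standing in for training.MODEL_KEY / ADAPTER_KEY / OPT_KEY:

# Illustrative only: an async-checkpointing resume starts from a base checkpoint
# that holds just the model weights, so both lookups below come back None.
checkpoint_dict = {"model": {"layer.weight": [0.0]}}

lora_weights_state_dict = (
    checkpoint_dict["adapter"] if "adapter" in checkpoint_dict else None
)
opt_state_dict = (
    checkpoint_dict["optimizer"] if "optimizer" in checkpoint_dict else None
)
assert lora_weights_state_dict is None and opt_state_dict is None
# Adapter and optimizer state are restored later by load_distributed_checkpoint().
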
@@ -345,11 +353,6 @@ def setup(self, cfg: DictConfig) -> None:
             shuffle=cfg.shuffle,
             batch_size=cfg.batch_size,
             collate_fn=collate_name,
-            dataloader_state_dict=(
-                checkpoint_dict[training.DATALOADER_KEY]
-                if self._resume_from_checkpoint
-                else None
-            ),
         )
 
         # Setup validation dataloader if validation dataset is provided
@@ -450,6 +453,16 @@ def _setup_model(
         self._lora_attn_modules = list(cfg_model.lora_attn_modules)
         self._apply_lora_to_mlp = cfg_model.apply_lora_to_mlp
         self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False)
+        self._adapter_config = {
+            "r": self._lora_rank,
+            "lora_alpha": self._lora_alpha,
+            "target_modules": get_lora_module_names(
+                self._lora_attn_modules,
+                self._apply_lora_to_mlp,
+                self._apply_lora_to_output,
+            ),
+            "peft_type": "LORA",
+        }
 
         utils.log_rank_zero(
             self._logger,
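
For reference, with the LoRA setup used in the new test below (LoRA on q_proj/v_proj only, no MLP or output adapters), the dict built here would look roughly like this; the rank and alpha values are placeholders rather than values from any shipped config:

# Roughly what self._adapter_config would hold for a q_proj/v_proj-only setup.
adapter_config = {
    "r": 8,                                  # placeholder LoRA rank
    "lora_alpha": 16,                        # placeholder LoRA alpha
    "target_modules": ["q_proj", "v_proj"],  # output of get_lora_module_names(...)
    "peft_type": "LORA",
}
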
@@ -585,7 +598,6 @@ def _setup_data(
         shuffle: bool,
         batch_size: int,
         collate_fn: str,
-        dataloader_state_dict: Optional[Dict[str, Any]] = None,
     ) -> StatefulDataLoader:
         """
         All data related setup happens here. This recipe currently supports only
@@ -637,115 +649,21 @@ def save_checkpoint(
         self,
         epoch: int,
     ) -> None:
-        """
-        Checkpoint the state of the recipe. The constructed checkpoint state dict
-        contains the following information:
-        - Merged weights with key MODEL_KEY
-        - Adapter weights with key ADAPTER_KEY
-        - Relevant recipe state if training is not complete
-        - If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights
-
-        Checkpointer will save the merged weights, adapter weights and recipe state in
-        different checkpoint files. To correctly resume from training, the adapter weights
-        and recipe state must be provided along with the base model weights.
-        """
-        # final dict passed onto the checkpointer
-        checkpoint_dict = {}
-
-        intermediate_checkpoint = epoch + 1 < self.total_epochs
-
-        utils.log_rank_zero(
-            self._logger,
-            "Saving checkpoint. This may take some time. Retrieving full model state dict...",
-        )
-        start = time.perf_counter()
-
-        # To prevent GPU memory from spiking during checkpoint save,
-        # we consolidate the full model and optim state dicts on CPU for rank 0
-        cpu_state_dict = training.gather_cpu_state_dict(
-            self._model,
-            self._is_rank_zero,
-            device=self._device,
-            adapter_weights_only=self._save_adapter_weights_only,
-        )
-        utils.log_rank_zero(
-            self._logger,
-            f"Getting full model state dict took {time.perf_counter() - start:.2f} secs",
+        self._checkpoint_client.save_checkpoint(
+            model=self._model,
+            optimizer=self._optimizer,
+            training_progress=TrainingProgress(
+                seed=self.seed,
+                epochs_run=self.epochs_run,
+                total_epochs=self.total_epochs,
+                max_steps_per_epoch=self.max_steps_per_epoch,
+                dataloader_state_dict=self._dataloader.state_dict(),
+            ),
+            epoch=epoch,
+            adapter_config=self._adapter_config.copy(),
+            adapter_only=self._save_adapter_weights_only,
         )
 
-        if intermediate_checkpoint:
-            utils.log_rank_zero(self._logger, "Retrieving optimizer state dict...")
-            opt_state_dict = training.get_full_optimizer_state_dict(
-                self._model,
-                self._optimizer,
-                self._is_rank_zero,
-                device=self._device,
-            )
-            utils.log_rank_zero(
-                self._logger,
-                f"Getting optimizer state dict took {time.perf_counter() - start:.2f} secs",
-            )
-        else:
-            opt_state_dict = None
-
-        # Now that we have the model and opt state dict, create the actual checkpoint dict
-        # to be sent to the checkpointer and ultimately written to file
-        if self._is_rank_zero:
-            start = time.perf_counter()
-
-            if self._save_adapter_weights_only:
-                adapter_state_dict = cpu_state_dict
-            else:
-                # Filter out the adapter keys and weights from the model state dict. These will
-                # be saved separately
-                adapter_state_dict = get_adapter_state_dict(cpu_state_dict)
-
-            # merge the adapter weights and base weights to create the model checkpoint
-            merged_state_dict = get_merged_lora_ckpt(
-                cpu_state_dict,
-                rank=self._lora_rank,
-                alpha=self._lora_alpha,
-            )
-            checkpoint_dict.update({training.MODEL_KEY: merged_state_dict})
-            checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict})
-
-            # if training is in-progress, checkpoint the optimizer state and recipe state
-            # as well.
-            if intermediate_checkpoint:
-                checkpoint_dict.update(
-                    {
-                        training.OPT_KEY: opt_state_dict,
-                        training.SEED_KEY: self.seed,
-                        training.EPOCHS_KEY: self.epochs_run,
-                        training.TOTAL_EPOCHS_KEY: self.total_epochs,
-                        training.MAX_STEPS_KEY: self.max_steps_per_epoch,
-                        training.DATALOADER_KEY: self._dataloader.state_dict(),
-                    }
-                )
-
-            adapter_config = {
-                "r": self._lora_rank,
-                "lora_alpha": self._lora_alpha,
-                "target_modules": get_lora_module_names(
-                    self._lora_attn_modules,
-                    self._apply_lora_to_mlp,
-                    self._apply_lora_to_output,
-                ),
-                "peft_type": "LORA",
-            }
-            checkpoint_dict.update({training.ADAPTER_CONFIG: adapter_config})
-            self._checkpointer.save_checkpoint(
-                checkpoint_dict,
-                epoch=epoch,
-                intermediate_checkpoint=intermediate_checkpoint,
-                adapter_only=self._save_adapter_weights_only,
-            )
-            self._logger.info(
-                f"Saving checkpoint took {time.perf_counter() - start:.2f} secs"
-            )
-
-        torch.distributed.barrier()
-
     def train(self) -> None:
         """
         The core training loop.
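
Everything the old save_checkpoint spread across SEED_KEY, EPOCHS_KEY, TOTAL_EPOCHS_KEY, MAX_STEPS_KEY and DATALOADER_KEY now travels as a single TrainingProgress object. A minimal construction mirroring the keyword arguments used above, with placeholder values:

# Field names mirror the keyword arguments in this diff; values are placeholders.
from torchtune.training.checkpointing._checkpoint_client import TrainingProgress

progress = TrainingProgress(
    seed=0,
    epochs_run=1,
    total_epochs=3,
    max_steps_per_epoch=None,
    dataloader_state_dict={},  # StatefulDataLoader.state_dict() in the recipe
)
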

tests/recipes/test_lora_finetune_distributed.py

Lines changed: 101 additions & 0 deletions
@@ -212,6 +212,107 @@ def test_training_state_on_resume(
             loss_values, expected_loss_values, rtol=1e-5, atol=1e-5
         )
 
+    @pytest.mark.integration_test
+    @gpu_test(gpu_count=2)
+    @pytest.mark.parametrize(
+        "config, model_type, ckpt_type, save_adapter_weights_only",
+        [
+            ("llama2/7B_lora", "llama2", "hf", False),
+            ("llama3/8B_lora", "llama3", "tune", False),
+            ("llama2/7B_lora", "llama2", "hf", True),
+        ],
+    )
+    def test_training_state_on_resume_with_async_checkpointing(
+        self,
+        config,
+        model_type,
+        ckpt_type,
+        tmpdir,
+        monkeypatch,
+        save_adapter_weights_only,
+    ):
+        """Test whether the recipe state is correctly updated on resume. Since this
+        is model agnostic, we should run this on the small model only. The test
+        consists of three stages:
+        - Train a model for 2 epochs
+        - Resume training after epoch 1
+        - Make sure final loss matches the expected value of a model successfully resumed from a ckpt
+        """
+        ckpt_component = CKPT_COMPONENT_MAP[ckpt_type]
+        ckpt = model_type + "_" + ckpt_type
+
+        ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
+        tokenizer_path = Path(TOKENIZER_PATHS[model_type])
+        ckpt_dir = ckpt_path.parent
+        log_file = gen_log_file_name(tmpdir)
+
+        # Config file needed for model conversion.
+        # Create a second copy for training resume
+        write_hf_ckpt_config(ckpt_dir)
+        write_hf_ckpt_config(tmpdir)
+
+        # Train for two epochs
+        cmd_1 = f"""
+        tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed \
+            --config {config} \
+            batch_size=4 \
+            gradient_accumulation_steps=1 \
+            output_dir={tmpdir} \
+            model.lora_attn_modules=['q_proj','v_proj'] \
+            model.apply_lora_to_mlp=False \
+            checkpointer._component_={ckpt_component} \
+            checkpointer.checkpoint_dir='{ckpt_dir}' \
+            checkpointer.checkpoint_files=[{ckpt_path}]\
+            checkpointer.output_dir={tmpdir} \
+            checkpointer.model_type={model_type.upper()} \
+            tokenizer.path='{tokenizer_path}' \
+            tokenizer.prompt_template=null \
+            save_adapter_weights_only={save_adapter_weights_only} \
+            enable_activation_checkpointing=True \
+            enable_activation_offloading=True \
+            enable_async_checkpointing=True \
+        """.split()
+
+        model_config = MODEL_TEST_CONFIGS[model_type + "_lora"]
+
+        cmd_1 = cmd_1 + self._get_test_config_overrides() + model_config
+        monkeypatch.setattr(sys, "argv", cmd_1)
+        runpy.run_path(TUNE_PATH, run_name="__main__")
+
+        # Resume training
+        cmd_2 = f"""
+        tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed \
+            --config {config} \
+            batch_size=4 \
+            gradient_accumulation_steps=1 \
+            output_dir={tmpdir} \
+            model.lora_attn_modules=['q_proj','v_proj'] \
+            model.apply_lora_to_mlp=False \
+            checkpointer._component_={ckpt_component} \
+            checkpointer.checkpoint_dir={ckpt_dir} \
+            checkpointer.checkpoint_files=[{ckpt_path}]\
+            checkpointer.output_dir={tmpdir} \
+            checkpointer.model_type={model_type.upper()} \
+            tokenizer.path='{tokenizer_path}' \
+            tokenizer.prompt_template=null \
+            resume_from_checkpoint=True \
+            metric_logger.filename={log_file} \
+            enable_activation_checkpointing=True \
+            enable_activation_offloading=True \
+            enable_async_checkpointing=True \
+        """.split()
+
+        cmd_2 = cmd_2 + self._get_test_config_overrides() + model_config
+        monkeypatch.setattr(sys, "argv", cmd_2)
+        runpy.run_path(TUNE_PATH, run_name="__main__")
+
+        expected_loss_values = self._fetch_expected_loss_values(model_type)[2:]
+
+        loss_values = get_loss_values_from_metric_logger(log_file)
+        torch.testing.assert_close(
+            loss_values, expected_loss_values, rtol=1e-5, atol=1e-5
+        )
+
     @pytest.mark.integration_test
     @pytest.mark.parametrize(
         "recipe_config, model_type, ckpt_type, use_dora",
