
Commit 5818bb8

LucasLLC authored and facebook-github-bot committed
wraps DDP models with DSD
Summary: Distributed State Dict (DSD) is PyTorch's currently recommended way to ensure that a parallelized model's state dict is compatible with saving and loading in single-process or re-sharding scenarios. This diff updates dcp_saver to use DSD for DDP models. A good idea would be to wrap all models in TNT with DSD, as this could replace some of the wrapper logic for FSDP and would guarantee future compatibility. N5551629 also contains a workaround for DDP checkpoints saved before this diff, which manually removes the "module." prefix in the checkpoint (the general idea is sketched below).

Differential Revision: D59234083
1 parent 5dad8d3 commit 5818bb8
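
The workaround mentioned in the summary lives in N5551629, which is not reproduced here; the sketch below only illustrates the general idea it describes, i.e. stripping DDP's "module." key prefix from a model state dict so it can be loaded into an unwrapped module. The helper name strip_ddp_prefix and the commented usage are illustrative assumptions, not taken from that notebook.

from typing import Any, Dict


def strip_ddp_prefix(state_dict: Dict[str, Any], prefix: str = "module.") -> Dict[str, Any]:
    # Return a copy of the state dict with DDP's key prefix removed so the
    # keys line up with the parameters of the unwrapped module.
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Usage, once the pre-change checkpoint's model state dict has been read
# (however that is done in a given setup):
# model.load_state_dict(strip_ddp_prefix(old_model_state_dict))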


torchtnt/framework/callbacks/dcp_saver.py

Lines changed: 33 additions & 2 deletions
@@ -19,8 +19,12 @@
     DefaultSavePlanner,
 )
 from torch.distributed.checkpoint.planner import LoadPlanner, SavePlanner
+from torch.distributed.checkpoint.state_dict import (
+    get_model_state_dict,
+    set_model_state_dict,
+)
 from torch.distributed.checkpoint.storage import StorageReader, StorageWriter
-
+from torch.nn.parallel import DistributedDataParallel
 from torchtnt.framework.callbacks._checkpoint_utils import (
     _prepare_app_state_for_checkpoint,
     _prepare_app_state_for_restore,
@@ -41,6 +45,7 @@
 from torchtnt.utils.checkpoint import BestCheckpointConfig, CheckpointPath
 from torchtnt.utils.optimizer import init_optim_state
 from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
+
 from torchtnt.utils.stateful import MultiStateful, Stateful


@@ -63,6 +68,24 @@
 )


+class DSDModelWrapper(Stateful):
+    """This wrapper converts state dicts to Distributed State Dicts, essentially generating
+    state dicts as if they were created using single-device methods. This is useful
+    when checkpointed models might be resharded, or loaded in notebooks or otherwise
+    non-distributed settings.
+
+    """
+
+    def __init__(self, mod: torch.nn.Module) -> None:
+        self.mod: torch.nn.Module = mod
+
+    def state_dict(self) -> Dict[str, Any]:
+        return get_model_state_dict(self.mod)
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        set_model_state_dict(self.mod, state_dict)
+
+
 class DistributedCheckpointSaver(BaseCheckpointer):
     """
     A callback which periodically saves the application state during training using `Distributed Checkpoint <https://pytorch.org/docs/stable/distributed.checkpoint.html>`_.
@@ -148,6 +171,11 @@ def _checkpoint_impl(
         curr_snapshot_wait = hook == "on_train_end"

         app_state = _prepare_app_state_for_checkpoint(state, unit, intra_epoch)
+
+        for key, obj in app_state.items():
+            if isinstance(obj, DistributedDataParallel):
+                app_state[key] = DSDModelWrapper(obj)
+
         # TODO: evaluate whether we need to implement the equivalent of torchsnapshot.RNGState()
         if self._async_checkpoint:
             with get_timing_context(state, f"{self.__class__.__name__}.async_save"):
@@ -315,14 +343,17 @@ def restore(
         )

         # necessary for loading optimizers since states are initialized lazily
-        for obj in app_state.values():
+        for key, obj in app_state.items():
             # sometimes optimizers are actually held in a wrapper which handles calling
             # state_dict and load_state_dict, as is the case for
             # `torchtnt.utils.prepare_module.FSDPOptimizerWrapper`; this handles that case.
             optimizer = getattr(obj, "optimizer", obj)
             if isinstance(optimizer, torch.optim.Optimizer):
                 init_optim_state(optimizer)

+            if isinstance(obj, DistributedDataParallel):
+                app_state[key] = DSDModelWrapper(obj)
+
         try:
             dcp.load(
                 {"app_state": MultiStateful(app_state)},
