1 change: 0 additions & 1 deletion nemo/collections/diffusion/recipes/flux_12b.py
@@ -113,7 +113,6 @@ def trainer(
gradient_accumulation_fusion=True,
ddp=run.Config(
DistributedDataParallelConfig,
# use_custom_fsdp=True,
# data_parallel_sharding_strategy='optim_grads_params',
check_for_nan_in_grad=True,
grad_reduce_in_fp32=True,
9 changes: 9 additions & 0 deletions nemo/lightning/_strategy_lib.py
@@ -574,6 +574,13 @@ def load_model_state_dict(megatron_parallel, checkpoint: Mapping[str, Any], stri
except ImportError or ModuleNotFoundError:
have_custom_fsdp = False

try:
from megatron.core.distributed import FullyShardedDataParallel

have_megatron_fsdp = True
except ImportError or ModuleNotFoundError:
have_megatron_fsdp = False

for index, module in enumerate(megatron_parallel):
if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
if "state_dict" in checkpoint:
@@ -612,6 +619,8 @@ def load_model_state_dict(megatron_parallel, checkpoint: Mapping[str, Any], stri

if have_custom_fsdp and hasattr(module, "module") and isinstance(module.module, FullyShardedDataParallel):
module.module.load_state_dict(_state_dict, strict=strict)
elif have_megatron_fsdp and hasattr(module, "module") and isinstance(module.module, FullyShardedDataParallel):
module.module.load_state_dict(_state_dict, strict=strict)
continue

try:
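For context, here is a minimal sketch (not part of this diff) of the availability-guard plus dispatch pattern the change relies on when loading weights into an FSDP-wrapped module. The helper name `load_fsdp_aware` and the fallback behaviour are assumptions for illustration only.

```python
# Sketch only: mirrors the guarded import and the load dispatch added above.
# `load_fsdp_aware` is a hypothetical helper, not a function in the PR.
try:
    from megatron.core.distributed import FullyShardedDataParallel
    HAVE_MEGATRON_FSDP = True
except (ImportError, ModuleNotFoundError):
    FullyShardedDataParallel = None
    HAVE_MEGATRON_FSDP = False


def load_fsdp_aware(module, state_dict, strict=True):
    """Load directly into the wrapped module when it is Megatron-FSDP sharded."""
    inner = getattr(module, "module", None)
    if HAVE_MEGATRON_FSDP and isinstance(inner, FullyShardedDataParallel):
        inner.load_state_dict(state_dict, strict=strict)
        return True  # caller can `continue` to the next module, as the diff does
    return False
```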
31 changes: 27 additions & 4 deletions nemo/lightning/megatron_parallel.py
@@ -67,6 +67,13 @@
except ImportError:
HAVE_CUSTOM_FSDP = False

try:
from megatron.core.distributed import FullyShardedDataParallel

HAVE_MEGATRON_FSDP = True
except ImportError:
HAVE_MEGATRON_FSDP = False

try:
from megatron.core.full_cuda_graph import FullCudaGraphWrapper

@@ -553,7 +560,7 @@ def init_model_parallel(self):
from megatron.core.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes

for model_module in self:
if not self._cpu and (not HAVE_CUSTOM_FSDP or self.fsdp != "megatron"):
if not self._cpu and ((not HAVE_MEGATRON_FSDP and not HAVE_CUSTOM_FSDP) or self.fsdp != "megatron"):
# If Megatron custom FSDP is enabled, we don't need to move the model to GPU here to avoid GPU OOM.
model_module.cuda(torch.cuda.current_device())

@@ -634,12 +641,16 @@ def init_ddp(self):
# Avoid rewrapping the module if it's already wrapped with FSDP
unwrapped_module = unwrap_model(module, Float16Module)
if (
HAVE_CUSTOM_FSDP
(HAVE_MEGATRON_FSDP or HAVE_CUSTOM_FSDP)
and self.fsdp == "megatron"
and not isinstance(unwrapped_module, FullyShardedDataParallel)
):
from nemo.utils import logging

if not getattr(module.config, "use_megatron_fsdp", False):
setattr(module.config, "use_megatron_fsdp", True)
logging.warning("Setting module.config.use_megatron_fsdp to True for MCore FSDP.")

if not getattr(module.config, "use_custom_fsdp", False):
setattr(module.config, "use_custom_fsdp", True)
logging.warning("Setting module.config.use_custom_fsdp to True for MCore FSDP.")
@@ -648,15 +659,27 @@
setattr(module.config, "gradient_accumulation_fusion", False)
logging.warning("Setting module.config.gradient_accumulation_fusion to False for MCore FSDP.")

assert module.config.use_custom_fsdp, "Custom FSDP is not enabled in module.config."
assert self.ddp_config.use_custom_fsdp, "Custom FSDP is not enabled in ddp_config."
if HAVE_MEGATRON_FSDP:
assert module.config.use_megatron_fsdp, "MCore FSDP is not enabled in module.config."
assert self.ddp_config.use_megatron_fsdp, "MCore FSDP is not enabled in ddp_config."
elif HAVE_CUSTOM_FSDP:
assert module.config.use_custom_fsdp, "MCore FSDP is not enabled in module.config."
assert self.ddp_config.use_custom_fsdp, "MCore FSDP is not enabled in ddp_config."
logging.warning(
"Deprecation Notice: `use_custom_fsdp` will be deprecated in M-Core 0.14. "
"Please use `use_megatron_fsdp` instead."
)

dist_module = FullyShardedDataParallel(
module.config,
self.ddp_config,
module,
disable_bucketing=disable_bucketing,
)
if HAVE_MEGATRON_FSDP:
dist_module.buffers = [dist_module.param_and_grad_buffer]
dist_module.config = module.config
dist_module.sharded_state_dict = lambda *args, **kwargs: dist_module.state_dict()
elif not isinstance(unwrapped_module, DDP):
dist_module = DDP(
module.config,
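A condensed sketch, under assumptions, of the wrapping decision after this change (the real logic lives inside `init_ddp()` and also handles `Float16Module` unwrapping and re-wrap avoidance). The point is that the new M-Core FSDP wrapper is given DDP-compatible aliases so downstream NeMo code that iterates `buffers` or calls `sharded_state_dict()` keeps working.

```python
# Sketch, not the PR's code verbatim: `module`, `ddp_config`, `fsdp`, and
# `disable_bucketing` stand in for the attributes used inside init_ddp().
if (HAVE_MEGATRON_FSDP or HAVE_CUSTOM_FSDP) and fsdp == "megatron":
    dist_module = FullyShardedDataParallel(
        module.config, ddp_config, module, disable_bucketing=disable_bucketing
    )
    if HAVE_MEGATRON_FSDP:
        # Alias the new wrapper's single param/grad buffer and plain state_dict()
        # so code written against the DDP wrapper keeps working unchanged.
        dist_module.buffers = [dist_module.param_and_grad_buffer]
        dist_module.config = module.config
        dist_module.sharded_state_dict = lambda *a, **kw: dist_module.state_dict()
else:
    dist_module = DDP(module.config, ddp_config, module, disable_bucketing=disable_bucketing)
```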
141 changes: 131 additions & 10 deletions nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -365,20 +365,33 @@

self._fsdp = None

if fsdp is None and self.ddp_config and self.ddp_config.use_custom_fsdp:
use_custom_fsdp = getattr(self.ddp_config, "use_custom_fsdp", False)
use_megatron_fsdp = getattr(self.ddp_config, "use_megatron_fsdp", False)
if fsdp is None and self.ddp_config and (use_custom_fsdp or use_megatron_fsdp):
logging.warning(
"FSDP option is not set but ddp_config.use_custom_fsdp is set to true. "
"FSDP option is not set but ddp_config use megatron-fsdp is set to true. "
"Setting FSDP option to megatron"
)
fsdp = 'megatron'
if use_megatron_fsdp and self.save_ckpt_format != "fsdp_dtensor":
raise NotImplementedError(
f"Megatron-FSDP checkpointing is not supported with {self.save_ckpt_format}."
)

if fsdp == "pytorch":
raise NotImplementedError("PyTorch FSDP2 is not supported with MegatronParallel.")
elif fsdp == "megatron":
self._fsdp = fsdp
if not self.ddp_config.use_custom_fsdp:
if hasattr(self.ddp_config, "use_custom_fsdp") and not use_custom_fsdp:
self.ddp_config.use_custom_fsdp = True
logging.warning("Setting ddp_config.use_custom_fsdp to True for MCore FSDP.")
logging.warning(
"Deprecation Notice: `use_custom_fsdp` will be deprecated in M-Core 0.14. "
"Please use `use_megatron_fsdp` instead."
)
elif hasattr(self.ddp_config, "use_megatron_fsdp") and not use_megatron_fsdp:
self.ddp_config.use_megatron_fsdp = True
logging.warning("Setting ddp_config.use_megatron_fsdp to True for MCore FSDP.")
logging.info("FSDP option is set to MCore. Using MCore's Custom FSDP for DP.")
elif fsdp is not None:
raise ValueError(f'Invalid DDP type: {fsdp}, please choose from ["megatron", "pytorch"].')
@@ -930,6 +943,54 @@
metadata=metadata,
)

def _get_fsdp_dtensor_state_dict(
self,
raw_state_dict,
model_key="model",
optimizer_key="optimizer_states",
):
from megatron.core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor import (
preprocess_state_dict_for_uneven_dtensor,
)
from megatron.core.transformer.fsdp_dtensor_checkpoint import (
handle_fp8_extra_state_case,
handle_swiglu_in_state_dict,
)

state_dict = raw_state_dict.copy()
handle_fp8_extra_state_case(state_dict[model_key])
module = self.model[0].module
if getattr(module.config, "gated_linear_unit", False):
model_state_dict = state_dict[model_key].copy()
if optimizer_key in state_dict:
optimizer_state_dict = state_dict[optimizer_key].copy()
else:
optimizer_state_dict = {}

Check notice — Code scanning / CodeQL: Cyclic import (Note). Import of module nemo.collections.llm.modelopt.model_utils begins an import cycle.
handle_swiglu_in_state_dict(module.module, model_state_dict, optimizer_state_dict)
state_dict[model_key] = model_state_dict
if optimizer_key in state_dict:
state_dict[optimizer_key] = optimizer_state_dict
preprocess_state_dict_for_uneven_dtensor(state_dict)

return state_dict

def _save_fsdp_dtensor_checkpoint(
self,
checkpoint: Dict[str, Any],
path,
storage_options,
):
state_dict = self._get_fsdp_dtensor_state_dict(checkpoint)

torch.distributed.checkpoint.save(
state_dict,
storage_writer=torch.distributed.checkpoint.FileSystemWriter(path),
)
self._save_fsdp_dtensor_common_state(state_dict=state_dict, ckpt_dir=path)

if "finalize_fn" in storage_options:
storage_options["finalize_fn"]()

@override
def save_checkpoint(
self, checkpoint: Dict[str, Any], filepath: Union[str, Path], storage_options: Optional[Any] = None
@@ -963,10 +1024,23 @@
if not storage_options:
storage_options = {}
storage_options['content_metadata'] = self.sharded_state_dict_metadata
self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options)
if self.save_ckpt_format == "fsdp_dtensor":
checkpoint = checkpoint.copy()
if "optimizer" in checkpoint:
checkpoint["optimizer_states"] = checkpoint.pop("optimizer")[0]
checkpoint["model"] = checkpoint.pop("sharded_state_dict")
self._save_fsdp_dtensor_checkpoint(
checkpoint=checkpoint,
path=ckpt_to_dir(filepath),
storage_options=storage_options,
)
checkpoint_io = None
else:
self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options)
checkpoint_io = self.checkpoint_io

# Save ModelOpt state too, if it exists.
save_modelopt_state(self.megatron_parallel, filepath, self.checkpoint_io)
save_modelopt_state(self.megatron_parallel, filepath, checkpoint_io)

def should_restore_optimizer_states(self, selective_restore: bool = False) -> bool:
"""Determines whether to restore optimizer states or not"""
@@ -975,6 +1049,32 @@

return self.ckpt_load_optimizer

def _save_fsdp_dtensor_common_state(self, state_dict, ckpt_dir):
state_dict = state_dict.copy()
del state_dict["model"]
del state_dict["optimizer_states"]
torch.save(state_dict, os.path.join(ckpt_dir, "common.pt"))

def _load_fsdp_dtensor_common_state(self, ckpt_dir):
return torch.load(os.path.join(ckpt_dir, "common.pt"), weights_only=False)

def _load_fsdp_dtensor_checkpoint(self, path, sharded_state_dict, strict):
from torch.distributed.checkpoint import default_planner

state_dict = self._get_fsdp_dtensor_state_dict(sharded_state_dict)

planner = default_planner.DefaultLoadPlanner(allow_partial_load=not strict)
torch.distributed.checkpoint.load(
state_dict,
checkpoint_id=path,
planner=planner,
)
sharded_state_dict.update(self._load_fsdp_dtensor_common_state(ckpt_dir=path))
if "loops" in sharded_state_dict:
sharded_state_dict["fit_loop"] = sharded_state_dict["loops"]["fit_loop"]

return sharded_state_dict

@override
def load_checkpoint(self, checkpoint_path: Union[str, Path], selective_restore: bool = False) -> Dict[str, Any]:
"""PTL method which we override to integrate distributed checkpoints for model parallel models.
@@ -994,7 +1094,10 @@
sharded_state_context = nullcontext

# After dist_checkpointing.load, sharded tensors will be replaced with tensors
sharded_sd_metadata = self.unwrapped_checkpoint_io.load_content_metadata(checkpoint_path)
if self.save_ckpt_format == "fsdp_dtensor":
sharded_sd_metadata = self.sharded_state_dict_metadata
else:
sharded_sd_metadata = self.unwrapped_checkpoint_io.load_content_metadata(checkpoint_path)
sharded_state_dict = {}
with sharded_state_context():
sharded_state_dict["state_dict"] = self.megatron_parallel.sharded_state_dict(metadata=sharded_sd_metadata)
@@ -1010,9 +1113,19 @@
)

try:
checkpoint = self.checkpoint_io.load_checkpoint(
checkpoint_path, sharded_state_dict=sharded_state_dict, strict=strict
)
if self.save_ckpt_format == "fsdp_dtensor":
sharded_state_dict["model"] = sharded_state_dict.pop("state_dict")
if "optimizer" in sharded_state_dict:
sharded_state_dict["optimizer_states"] = sharded_state_dict.pop("optimizer")[0]
checkpoint = self._load_fsdp_dtensor_checkpoint(
path=ckpt_to_dir(checkpoint_path),
sharded_state_dict=sharded_state_dict,
strict=strict,
)
else:
checkpoint = self.checkpoint_io.load_checkpoint(
checkpoint_path, sharded_state_dict=sharded_state_dict, strict=strict
)
except CheckpointException as e:
error_message = f"{e}\n{LOAD_ERROR}"
raise RuntimeError(error_message)
@@ -1031,7 +1144,10 @@
"""Metadata used for sharded_state_dict generation during checkpoint save."""
metadata = {}
if isinstance(self.ddp_config, DistributedDataParallelConfig) and self.ddp_config.use_distributed_optimizer:
if self.parallel_save_optim:
use_megatron_fsdp = getattr(self.ddp_config, "use_megatron_fsdp", False)
if use_megatron_fsdp:
metadata["distrib_optim_sharding_type"] = "fsdp_dtensor"
elif self.parallel_save_optim:
metadata["distrib_optim_sharding_type"] = "fully_sharded_model_space"
else:
metadata["distrib_optim_sharding_type"] = "dp_zero_gather_scatter"
@@ -1072,6 +1188,11 @@

mesh = DeviceMesh.from_group(parallel_state.get_data_parallel_group(), "cuda")

if self.save_ckpt_format == "fsdp_dtensor":
assert len(self.optimizers) == 1, "FSDP DTensor format requires a single optimizer."
self.optimizers[0].load_state_dict(checkpoint["optimizer_states"])
return

optimizer_states = checkpoint["optimizer"]
for optimizer, opt_state in zip(self.optimizers, optimizer_states):
if self._fsdp is not None:
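For orientation, a minimal sketch (assumptions noted in comments) of the torch.distributed.checkpoint round trip the new `fsdp_dtensor` format builds on. In the strategy the model and optimizer entries are DTensor-sharded, and non-tensor state is written separately to `common.pt`.

```python
# Sketch only: directory layout and state_dict contents are assumptions; the
# strategy additionally splits non-tensor state out into common.pt.
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint import FileSystemWriter, default_planner


def save_fsdp_dtensor(state_dict, ckpt_dir):
    # Each rank writes only the shards it owns; DTensor metadata records the layout.
    dcp.save(state_dict, storage_writer=FileSystemWriter(ckpt_dir))


def load_fsdp_dtensor(state_dict, ckpt_dir, strict=True):
    # In-place load: `state_dict` must already hold DTensors with the target sharding.
    planner = default_planner.DefaultLoadPlanner(allow_partial_load=not strict)
    dcp.load(state_dict, checkpoint_id=ckpt_dir, planner=planner)
    return state_dict
```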
19 changes: 13 additions & 6 deletions scripts/dit/dit_train.py
@@ -191,6 +191,13 @@ def pretrain_l() -> run.Partial:
return recipe


def set_use_megatron_fsdp(recipe):
try:
recipe.trainer.strategy.ddp.use_megatron_fsdp = True
except AttributeError:
recipe.trainer.strategy.ddp.use_custom_fsdp = True


@run.cli.factory(target=llm.train)
def train_mock() -> run.Partial:
"""DiT Mock Pretraining Recipe"""
@@ -210,7 +217,7 @@ def train_mock() -> run.Partial:
recipe.data.model_config = recipe.model.config
recipe.log.log_dir = 'nemo_experiments/train_mock'

recipe.trainer.strategy.ddp.use_custom_fsdp = True
set_use_megatron_fsdp(recipe=recipe)
recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = 'optim_grads_params'
recipe.trainer.strategy.ddp.overlap_param_gather = True
recipe.trainer.strategy.ddp.overlap_grad_reduce = True
@@ -236,7 +243,7 @@ def mock_ditllama5b_8k() -> run.Partial:
recipe.data.model_config = recipe.model.config
recipe.log.log_dir = 'nemo_experiments/mock_ditllama5b_8k'
recipe.model.config.attn_mask_type = AttnMaskType.no_mask
recipe.trainer.strategy.ddp.use_custom_fsdp = True
set_use_megatron_fsdp(recipe=recipe)
recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = 'optim_grads_params'
recipe.trainer.strategy.ddp.overlap_param_gather = True
recipe.trainer.strategy.ddp.overlap_grad_reduce = True
@@ -360,7 +367,7 @@ def pretrain_ditllama30b() -> run.Partial:
recipe.data.task_encoder.seq_length = 256
recipe.data.virtual_epoch_length = 0
recipe.log.log_dir = 'nemo_experiments/ditllama30b_stage1_mock'
recipe.trainer.strategy.ddp.use_custom_fsdp = True
set_use_megatron_fsdp(recipe=recipe)
recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = 'optim_grads_params'
recipe.trainer.strategy.ddp.overlap_param_gather = True
recipe.trainer.strategy.ddp.overlap_grad_reduce = True
@@ -386,7 +393,7 @@ def pretrain_ditllama30b_stage2_mock() -> run.Partial:
recipe.trainer.val_check_interval = 1.0
recipe.data.model_config = recipe.model.config
recipe.log.log_dir = 'nemo_experiments/ditllama30b_stage2_mock'
recipe.trainer.strategy.ddp.use_custom_fsdp = True
set_use_megatron_fsdp(recipe=recipe)
recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = 'optim_grads_params'
recipe.trainer.strategy.ddp.overlap_param_gather = True
recipe.trainer.strategy.ddp.overlap_grad_reduce = True
@@ -412,7 +419,7 @@ def pretrain_ditllama30b_stage3_mock() -> run.Partial:
recipe.trainer.val_check_interval = 1.0
recipe.data.model_config = recipe.model.config
recipe.log.log_dir = 'nemo_experiments/ditllama30b_stage3_mock'
recipe.trainer.strategy.ddp.use_custom_fsdp = True
set_use_megatron_fsdp(recipe=recipe)
recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = 'optim_grads_params'
recipe.trainer.strategy.ddp.overlap_param_gather = True
recipe.trainer.strategy.ddp.overlap_grad_reduce = True
@@ -512,7 +519,7 @@ def pretrain_ecditllama1b() -> run.Partial:
recipe.log.log_dir = 'nemo_experiments/ecditllama1b'
recipe.trainer.val_check_interval = 3000

recipe.trainer.strategy.ddp.use_custom_fsdp = True
set_use_megatron_fsdp(recipe=recipe)
recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = 'optim_grads_params'
recipe.trainer.strategy.ddp.overlap_param_gather = True
recipe.trainer.strategy.ddp.overlap_grad_reduce = True
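As a usage note (not part of the diff), the new helper can be applied to any recipe whose trainer strategy carries a DDP config, exactly as the factories above do; `pretrain_ecditllama1b` is just one of the recipes defined in this file.

```python
# Hypothetical driver snippet: build a recipe, opt it into Megatron FSDP, and
# keep the sharding settings the factories set alongside the flag.
recipe = pretrain_ecditllama1b()
set_use_megatron_fsdp(recipe=recipe)  # prefers use_megatron_fsdp, falls back to use_custom_fsdp
recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = 'optim_grads_params'
recipe.trainer.strategy.ddp.overlap_param_gather = True
recipe.trainer.strategy.ddp.overlap_grad_reduce = True
```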
2 changes: 0 additions & 2 deletions scripts/flux/flux_controlnet_training.py
@@ -92,7 +92,6 @@ def flux_controlnet_training() -> run.Partial:
pipeline_dtype=torch.bfloat16,
ddp=run.Config(
DistributedDataParallelConfig,
use_custom_fsdp=True,
data_parallel_sharding_strategy='optim_grads_params',
check_for_nan_in_grad=True,
grad_reduce_in_fp32=True,
@@ -292,7 +291,6 @@ def unit_test(custom_fsdp=True) -> run.Partial:
def configure_custom_fsdp(recipe) -> run.Partial:
recipe.trainer.strategy.ddp = run.Config(
DistributedDataParallelConfig,
use_custom_fsdp=True,
data_parallel_sharding_strategy='optim_grads_params', # Custom FSDP
check_for_nan_in_grad=True,
grad_reduce_in_fp32=True,
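With the hard-coded `use_custom_fsdp=True` removed here, a recipe that still wants Megatron FSDP can opt in explicitly. A sketch under the assumption that newer Megatron-Core exposes `use_megatron_fsdp` while older builds only have the deprecated `use_custom_fsdp`:

```python
# Sketch only: set the flag after constructing the config, preferring the new name.
ddp = run.Config(
    DistributedDataParallelConfig,
    data_parallel_sharding_strategy='optim_grads_params',
    check_for_nan_in_grad=True,
    grad_reduce_in_fp32=True,
)
try:
    ddp.use_megatron_fsdp = True      # newer Megatron-Core (assumed)
except AttributeError:
    ddp.use_custom_fsdp = True        # older Megatron-Core; deprecated in M-Core 0.14
```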