Skip to content

Commit f8a59d8

Browse files
committed
revert back pretrain_llama3_8b.py
format code Signed-off-by: jianbinc <shjwudp@gmail.com>
1 parent 0c17d92 commit f8a59d8

File tree

2 files changed

+7
-54
lines changed

2 files changed

+7
-54
lines changed

nemo/lightning/pytorch/strategies/megatron_strategy.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,8 @@ def __init__(
371371
"Setting FSDP option to megatron"
372372
)
373373
fsdp = 'megatron'
374+
if self.save_ckpt_format != "fsdp_dtensor":
375+
raise NotImplementedError(f"FSDP checkpointing is not supported with {self.save_ckpt_format}.")
374376

375377
if fsdp == "pytorch":
376378
raise NotImplementedError("PyTorch FSDP2 is not supported with MegatronParallel.")
@@ -936,28 +938,24 @@ def _get_fsdp_dtensor_state_dict(
936938
model_key="model",
937939
optimizer_key="optimizer_states",
938940
):
941+
from megatron.core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor import (
942+
preprocess_state_dict_for_uneven_dtensor,
943+
)
939944
from megatron.core.transformer.fsdp_dtensor_checkpoint import (
940945
handle_fp8_extra_state_case,
941946
handle_swiglu_in_state_dict,
942947
)
943-
from megatron.core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor import (
944-
preprocess_state_dict_for_uneven_dtensor,
945-
)
946948

947949
state_dict = raw_state_dict.copy()
948950
handle_fp8_extra_state_case(state_dict[model_key])
949951
module = self.model[0].module
950-
if torch.distributed.get_rank() == 0:
951-
print(self.model, module)
952952
if getattr(module.config, "gated_linear_unit", False):
953953
model_state_dict = state_dict[model_key].copy()
954954
if optimizer_key in state_dict:
955955
optimizer_state_dict = state_dict[optimizer_key].copy()
956956
else:
957957
optimizer_state_dict = {}
958-
handle_swiglu_in_state_dict(
959-
module.module, model_state_dict, optimizer_state_dict
960-
)
958+
handle_swiglu_in_state_dict(module.module, model_state_dict, optimizer_state_dict)
961959
state_dict[model_key] = model_state_dict
962960
if optimizer_key in state_dict:
963961
state_dict[optimizer_key] = optimizer_state_dict
@@ -1060,9 +1058,7 @@ def _load_fsdp_dtensor_checkpoint(self, path, sharded_state_dict, strict):
10601058
checkpoint_id=path,
10611059
planner=planner,
10621060
)
1063-
sharded_state_dict.update(
1064-
self._load_fsdp_dtensor_common_state(ckpt_dir=path)
1065-
)
1061+
sharded_state_dict.update(self._load_fsdp_dtensor_common_state(ckpt_dir=path))
10661062
if "loops" in sharded_state_dict:
10671063
sharded_state_dict["fit_loop"] = sharded_state_dict["loops"]["fit_loop"]
10681064

scripts/performance/llm/pretrain_llama3_8b.py

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -84,49 +84,6 @@ def override_recipe_configs(
8484
recipe = set_exp_logging_configs(
8585
recipe, "pre_train", "llm", "llama3", args.tensorboard, args.wandb, args.wandb_prj_name, args.wandb_job_name
8686
)
87-
# for saving checkpoints
88-
ckpt_path = "/lustre/fsw/coreai_devtech_all/jianbinc/playground/nemo_nvfsdp_update/NeMo/checkpoints"
89-
recipe.log.log_dir = ckpt_path
90-
import nemo.lightning as nl
91-
import nemo_run as run
92-
93-
recipe.log.ckpt = run.Config(
94-
nl.ModelCheckpoint,
95-
train_time_interval=None,
96-
save_last=True,
97-
every_n_train_steps=100,
98-
save_top_k=1,
99-
save_on_train_epoch_end=True,
100-
save_optim_on_train_end=True,
101-
always_save_context=False,
102-
filename="{model_name}--{val_loss:.2f}-{step}-{consumed_samples}",
103-
)
104-
105-
# nl.ModelCheckpoint(
106-
# train_time_interval=None,
107-
# )
108-
# # recipe.log.ckpt.train_time_interval = None
109-
# recipe.log.ckpt.save_last = True
110-
# recipe.log.ckpt.every_n_train_steps = 100
111-
# recipe.log.ckpt.save_top_k = 1
112-
# recipe.log.ckpt.save_on_train_epoch_end = True
113-
# recipe.log.ckpt.save_optim_on_train_end = True
114-
# recipe.log.ckpt.always_save_context = False
115-
116-
# for loading checkpoints
117-
recipe.resume.resume_if_exists = True
118-
recipe.resume.resume_ignore_no_checkpoint = True
119-
# recipe.resume.restore_config = RestoreConfig(
120-
# path=ckpt_path,
121-
# load_model_state=True,
122-
# load_optim_state=True,
123-
# )
124-
125-
recipe.trainer.strategy.save_ckpt_format = "fsdp_dtensor"
126-
recipe.trainer.strategy.ddp.average_in_collective = False
127-
# recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = "optim"
128-
129-
recipe.optim.config.use_precision_aware_optimizer = False
13087

13188
# data module configs
13289
if args.use_hf_tokenizer:

0 commit comments

Comments
 (0)