Skip to content

Commit 795ed3e

Browse files
authored
Merge branch 'main' into maanug/perf-hf-env-var
2 parents 321cbca + cbb1192 commit 795ed3e

File tree

2 files changed

+20
-5
lines changed

2 files changed

+20
-5
lines changed

nemo/core/connectors/save_restore_connector.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -756,8 +756,23 @@ def _save_state_dict_to_disk(state_dict, filepath):
756756
torch.save(state_dict, filepath)
757757

758758
@staticmethod
759-
def _load_state_dict_from_disk(model_weights, map_location=None):
760-
return torch.load(model_weights, map_location='cpu', weights_only=False)
759+
def _load_state_dict_from_disk(model_weights, map_location='cpu'):
760+
"""
761+
Load model state dict from disk.
762+
763+
Args:
764+
model_weights: Path to the checkpoint file
765+
map_location: Device to map tensors to
766+
767+
Returns:
768+
State dict loaded from checkpoint
769+
770+
"""
771+
try:
772+
return torch.load(model_weights, map_location=map_location, weights_only=True)
773+
except Exception as e:
774+
logging.error(f"Failed to load checkpoint with weights_only=True: {e}")
775+
raise e
761776

762777
@property
763778
def model_config_yaml(self) -> str:

requirements/requirements_lightning.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,6 @@ torchmetrics>=0.11.0
88
transformers~=4.53.0
99
wandb
1010
webdataset>=0.2.86
11-
nv_one_logger_core>=2.3.0
12-
nv_one_logger_training_telemetry>=2.3.0
13-
nv_one_logger_pytorch_lightning_integration>=2.3.0
11+
nv_one_logger_core>=2.3.1
12+
nv_one_logger_training_telemetry>=2.3.1
13+
nv_one_logger_pytorch_lightning_integration>=2.3.1

0 commit comments

Comments
 (0)