Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 5 additions & 11 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3357,10 +3357,7 @@ def register_load_state_pre_hook(self, hook: Callable[..., None]) -> hooks.Remov
def load_state(
self,
input_dir: str = None,
optimizer_load_kwargs: dict[str, Any] = {},
scheduler_load_kwargs: dict[str, Any] = {},
dataloader_load_kwargs: dict[str, Any] = {},
**load_model_func_kwargs,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please keep load_model_func_kwargs, as load_model has different kwargs compared to load

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The load_model does not accept kwargs. Let me know if I am mistaken.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in safetensors, we have the following

def load_model(
    model: torch.nn.Module, filename: Union[str, os.PathLike], strict: bool = True, device: Union[str, int] = "cpu"
) -> Tuple[List[str], List[str]]:

If you can revert the changes related to load_model_func_kwargs and only update load_kwargs to where we use load, it will be better

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The way it is being used, it is missing only the strict argument, but that's ok. I reverted the load_model_func_kwargs.

c3d7f1f

**load_kwargs,
):
"""
Loads the current states of the model, optimizer, scaler, RNG generators, and registered objects.
Expand Down Expand Up @@ -3421,11 +3418,11 @@ def _inner(folder):
elif self.distributed_type == DistributedType.DEEPSPEED:
logger.info("Loading DeepSpeed Model and Optimizer")
ckpt_id = f"{MODEL_NAME}" if i == 0 else f"{MODEL_NAME}_{i}"
model.load_checkpoint(input_dir, ckpt_id, **load_model_func_kwargs)
model.load_checkpoint(input_dir, ckpt_id, **load_kwargs)
logger.info(f"DeepSpeed Model and Optimizer loaded from input dir {os.path.join(input_dir, ckpt_id)}")
elif self.distributed_type == DistributedType.MEGATRON_LM:
logger.info("Loading Megatron-LM Model, Optimizer and Scheduler")
model.load_checkpoint(input_dir)
model.load_checkpoint(input_dir, **load_kwargs)
logger.info(f"Megatron-LM Model , Optimizer and Scheduler loaded from input dir {input_dir}")
else:
models.append(model)
Expand Down Expand Up @@ -3457,7 +3454,7 @@ def _inner(folder):
for hook in self._load_model_state_pre_hook.values():
hook(models, input_dir)

map_location = load_model_func_kwargs.pop("map_location", None)
map_location = load_kwargs.pop("map_location", None)
if map_location is None:
if self.num_processes > 1 and self.distributed_type in (
DistributedType.MULTI_GPU,
Expand All @@ -3480,10 +3477,7 @@ def _inner(folder):
self.state.process_index,
self.scaler,
map_location,
optimizer_load_kwargs,
scheduler_load_kwargs,
dataloader_load_kwargs,
**load_model_func_kwargs,
**load_kwargs,
)
if "step" in override_attributes:
self.step = override_attributes["step"]
Expand Down
31 changes: 11 additions & 20 deletions src/accelerate/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,7 @@ def load_accelerator_state(
process_index,
scaler=None,
map_location=None,
optimizer_load_kwargs={},
scheduler_load_kwargs={},
dataloader_load_kwargs={},
**load_model_func_kwargs,
load_kwargs=None,
):
"""
Loads states of the models, optimizers, scaler, and RNG generators from a given directory.
Expand All @@ -203,14 +200,8 @@ def load_accelerator_state(
An optional *GradScaler* instance to load
map_location (`str`, *optional*):
What device to load the optimizer state onto. Should be one of either "cpu" or "on_device".
optimizer_load_kwargs (`dict`, *optional*):
Additional arguments that can be passed to the optimizer's `load` function.
scheduler_load_kwargs (`dict`, *optional*):
Additional arguments that can be passed to the scheduler's `load` function.
dataloader_load_kwargs (`dict`, *optional*):
Additional arguments that can be passed to the dataloader's `load` function.
load_model_func_kwargs (`dict`, *optional*):
Additional arguments that can be passed to the model's `load_state_dict` method.
load_kwargs (`dict`, *optional*):
Additional arguments that can be passed to the `load`, `load_model` and `load_state_dict` functions.

Returns:
`dict`: Contains the `Accelerator` attributes to override while loading the state.
Expand All @@ -232,28 +223,28 @@ def load_accelerator_state(
ending = f"_{i}" if i > 0 else ""
input_model_file = input_dir.joinpath(f"{SAFE_MODEL_NAME}{ending}.safetensors")
if input_model_file.exists():
load_model(model, input_model_file, device=str(map_location), **load_model_func_kwargs)
load_model(model, input_model_file, device=str(map_location), **load_kwargs)
else:
# Load with torch
input_model_file = input_dir.joinpath(f"{MODEL_NAME}{ending}.bin")
state_dict = load(input_model_file, map_location=map_location)
model.load_state_dict(state_dict, **load_model_func_kwargs)
model.load_state_dict(state_dict, **load_kwargs)
logger.info("All model weights loaded successfully")

# Optimizer states
for i, opt in enumerate(optimizers):
optimizer_name = f"{OPTIMIZER_NAME}.bin" if i == 0 else f"{OPTIMIZER_NAME}_{i}.bin"
input_optimizer_file = input_dir.joinpath(optimizer_name)
optimizer_state = load(input_optimizer_file, map_location=map_location, **optimizer_load_kwargs)
optimizers[i].load_state_dict(optimizer_state)
optimizer_state = load(input_optimizer_file, map_location=map_location, **load_kwargs)
optimizers[i].load_state_dict(optimizer_state, **load_kwargs)
logger.info("All optimizer states loaded successfully")

# Scheduler states
for i, scheduler in enumerate(schedulers):
scheduler_name = f"{SCHEDULER_NAME}.bin" if i == 0 else f"{SCHEDULER_NAME}_{i}.bin"
input_scheduler_file = input_dir.joinpath(scheduler_name)
scheduler_state = load(input_scheduler_file, map_location=None, **scheduler_load_kwargs)
scheduler.load_state_dict(scheduler_state)
scheduler_state = load(input_scheduler_file, map_location=None, **load_kwargs)
scheduler.load_state_dict(scheduler_state, **load_kwargs)
logger.info("All scheduler states loaded successfully")

for i, dataloader in enumerate(dataloaders):
Expand All @@ -270,8 +261,8 @@ def load_accelerator_state(
dataloader_state_dict_name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
input_dataloader_state_dict_file = input_dir.joinpath(dataloader_state_dict_name)
if input_dataloader_state_dict_file.exists():
state_dict = load(input_dataloader_state_dict_file, map_location=None, **dataloader_load_kwargs)
dataloader.load_state_dict(state_dict)
state_dict = load(input_dataloader_state_dict_file, map_location=None, **load_kwargs)
dataloader.load_state_dict(state_dict, **load_kwargs)
logger.info("All dataloader sampler states loaded successfully")

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's also include it for scaler and states

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

# GradScaler state
Expand Down