Skip to content
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,9 +437,9 @@ def __init__(
self.has_fp8_handler = False
if kwargs_handlers is not None:
for handler in kwargs_handlers:
assert isinstance(handler, KwargsHandler), (
f"Unsupported kwargs handler passed: {handler}, must be one that inherits `accelerate.utils.KwargsHandler`."
)
assert isinstance(
handler, KwargsHandler
), f"Unsupported kwargs handler passed: {handler}, must be one that inherits `accelerate.utils.KwargsHandler`."
# Add the handler class to the set of found handlers
if handler.__class__ in found_handlers:
raise ValueError(f"You can only pass one {handler.__class__} in `kwargs_handlers`.")
Expand Down Expand Up @@ -3358,7 +3358,7 @@ def register_load_state_pre_hook(self, hook: Callable[..., None]) -> hooks.Remov
self._load_model_state_pre_hook[handle.id] = hook
return handle

def load_state(self, input_dir: str = None, **load_model_func_kwargs):
def load_state(self, input_dir: str = None, load_kwargs: dict | None = None, **load_model_func_kwargs):
"""
Loads the current states of the model, optimizer, scaler, RNG generators, and registered objects.

Expand All @@ -3373,6 +3373,9 @@ def load_state(self, input_dir: str = None, **load_model_func_kwargs):
input_dir (`str` or `os.PathLike`):
The name of the folder all relevant weights and states were saved in. Can be `None` if
`automatic_checkpoint_naming` is used, and will pick up from the latest checkpoint.
load_kwargs (`dict`, *optional*):
Additional keyword arguments for the underlying `load` function, such as optional arguments for
loading the optimizer, scheduler, and dataloader state dicts.
load_model_func_kwargs (`dict`, *optional*):
Additional keyword arguments for loading model which can be passed to the underlying load function,
such as optional arguments for DeepSpeed's `load_checkpoint` function or a `map_location` to load the
Expand Down Expand Up @@ -3477,6 +3480,7 @@ def _inner(folder):
self.state.process_index,
self.scaler,
map_location,
load_kwargs,
**load_model_func_kwargs,
)
if "step" in override_attributes:
Expand Down
9 changes: 6 additions & 3 deletions src/accelerate/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def load_accelerator_state(
process_index,
scaler=None,
map_location=None,
load_kwargs=None,
**load_model_func_kwargs,
):
"""
Expand All @@ -200,6 +201,8 @@ def load_accelerator_state(
An optional *GradScaler* instance to load
map_location (`str`, *optional*):
What device to load the optimizer state onto. Should be one of either "cpu" or "on_device".
load_kwargs (`dict`, *optional*):
Additional arguments that can be passed to the `load` function.
load_model_func_kwargs (`dict`, *optional*):
Additional arguments that can be passed to the model's `load_state_dict` method.

Expand Down Expand Up @@ -235,15 +238,15 @@ def load_accelerator_state(
for i, opt in enumerate(optimizers):
optimizer_name = f"{OPTIMIZER_NAME}.bin" if i == 0 else f"{OPTIMIZER_NAME}_{i}.bin"
input_optimizer_file = input_dir.joinpath(optimizer_name)
optimizer_state = load(input_optimizer_file, map_location=map_location)
optimizer_state = load(input_optimizer_file, map_location=map_location, **load_kwargs)
optimizers[i].load_state_dict(optimizer_state)
logger.info("All optimizer states loaded successfully")

# Scheduler states
for i, scheduler in enumerate(schedulers):
scheduler_name = f"{SCHEDULER_NAME}.bin" if i == 0 else f"{SCHEDULER_NAME}_{i}.bin"
input_scheduler_file = input_dir.joinpath(scheduler_name)
scheduler_state = load(input_scheduler_file)
scheduler_state = load(input_scheduler_file, **load_kwargs)
scheduler.load_state_dict(scheduler_state)
logger.info("All scheduler states loaded successfully")

Expand All @@ -261,7 +264,7 @@ def load_accelerator_state(
dataloader_state_dict_name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
input_dataloader_state_dict_file = input_dir.joinpath(dataloader_state_dict_name)
if input_dataloader_state_dict_file.exists():
state_dict = load(input_dataloader_state_dict_file)
state_dict = load(input_dataloader_state_dict_file, **load_kwargs)
dataloader.load_state_dict(state_dict)
logger.info("All dataloader sampler states loaded successfully")

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's also include it for scaler and states

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Expand Down