Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3315,6 +3315,7 @@ def _save_checkpoint(self, model, trial):

run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder)
os.makedirs(output_dir, exist_ok=True)
self.save_model(output_dir, _internal_call=True)

if self.args.save_strategy in [SaveStrategy.STEPS, SaveStrategy.EPOCH] and self.state.best_global_step:
Expand All @@ -3330,6 +3331,13 @@ def _save_checkpoint(self, model, trial):
self._save_scaler(output_dir)
# Save RNG state
self._save_rng_state(output_dir)
elif self.is_fsdp_enabled and (
"SHARDED_STATE_DICT" in str(self.accelerator.state.fsdp_plugin.state_dict_type)
):
# self.save_model above only handles FULL_STATE_DICT
save_fsdp_model(
self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir, **_get_fsdp_ckpt_kwargs()
)

# Save the Trainer state
if self.args.should_save:
Expand Down Expand Up @@ -5384,12 +5392,6 @@ def create_accelerator_and_postprocess(self):
raise ValueError(
"`auto_find_batch_size` isn't supported yet with DeepSpeed Zero-3. Please consider using Zero-2, Zero-1, or FSDP"
)
if (
self.args.save_only_model
and self.is_fsdp_enabled
and "SHARDED_STATE_DICT" in str(self.accelerator.state.fsdp_plugin.state_dict_type)
):
raise ValueError("save_only_model option is not compatible with FSDP state dict type 'SHARDED_STATE_DICT'")

def propagate_args_to_deepspeed(self, auto_find_batch_size=False):
"""
Expand Down