Merged
Changes from 1 commit
4 changes: 3 additions & 1 deletion src/transformers/modeling_utils.py
@@ -4907,7 +4907,9 @@ def _load_pretrained_model(
model_to_load, state_dict, start_prefix
)
# at this point the state dict should be on cpu, we don't need to actually read it
- fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(state_dict)
+ mismatched_names = [name for name, _, _ in mismatched_keys]
+ fixed_state_dict = {k: v for k, v in state_dict.items() if k not in mismatched_names}
+ fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(fixed_state_dict)
model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
Comment on lines +4910 to 4913
Contributor Author

strict=False only tolerates missing or unexpected keys; load_state_dict still raises on weights with mismatched sizes, so filter those keys out first.
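
For illustration, here is a minimal sketch of the filtering step in isolation. It assumes mismatched_keys holds (name, checkpoint_shape, model_shape) tuples, as the three-way unpacking in the diff suggests. The helper name drop_mismatched and the sample keys and shapes are hypothetical, and placeholder strings stand in for tensors so the snippet runs without PyTorch. It also uses a set rather than the list from the diff for the membership check, which is a small deviation from the change itself.

```python
def drop_mismatched(state_dict, mismatched_keys):
    """Return a copy of state_dict without the entries named in mismatched_keys.

    mismatched_keys is assumed to be an iterable of
    (name, checkpoint_shape, model_shape) tuples.
    """
    # A set gives constant-time lookups; the diff uses a list, which behaves the same.
    mismatched_names = {name for name, _, _ in mismatched_keys}
    return {k: v for k, v in state_dict.items() if k not in mismatched_names}


# Hypothetical checkpoint: strings stand in for tensors.
state_dict = {
    "encoder.weight": "w0",
    "classifier.weight": "w1",
    "classifier.bias": "b1",
}
# Hypothetical mismatches: a 10-class head in the checkpoint, a 2-class head in the model.
mismatched_keys = [
    ("classifier.weight", (10, 768), (2, 768)),
    ("classifier.bias", (10,), (2,)),
]

fixed_state_dict = drop_mismatched(state_dict, mismatched_keys)
print(sorted(fixed_state_dict))  # ['encoder.weight']
```

Only the filtered dictionary would then be passed on to the key-fixing step and load_state_dict; the original state_dict is left untouched.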

else:
# This should always be a list but, just to be sure.