Commit cc88a51

qubvel authored and garrett361 committed
Fix loading models with mismatched sizes (huggingface#36463)
* Fix loading model with mismatched sizes
* trigger tests
1 parent 08d2dfb commit cc88a51

1 file changed, +3 -1 lines changed


src/transformers/modeling_utils.py

Lines changed: 3 additions & 1 deletion
@@ -4907,7 +4907,9 @@ def _load_pretrained_model(
         model_to_load, state_dict, start_prefix
     )
     # at this point the state dict should be on cpu, we don't need to actually read it
-    fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(state_dict)
+    mismatched_names = [name for name, _, _ in mismatched_keys]
+    fixed_state_dict = {k: v for k, v in state_dict.items() if k not in mismatched_names}
+    fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(fixed_state_dict)
     model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
 else:
     # This should always be a list but, just to be sure.
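
The change filters mismatched entries out of the state dict before it reaches load_state_dict. In PyTorch, strict=False tolerates missing and unexpected keys but still raises on a shape mismatch, so those entries have to be dropped beforehand. Below is a minimal sketch of that filtering step in plain Python, without torch. The helper names (find_mismatched, drop_mismatched) are hypothetical, lists stand in for tensors, and len() stands in for a tensor shape. The diff only shows that each mismatched_keys entry is a triple with the name first; treating the other two fields as the checkpoint size and the model size is an assumption made for illustration.

```python
# Hypothetical stand-ins: plain lists play the role of tensors,
# and len() plays the role of a tensor shape.

def find_mismatched(state_dict, expected_sizes):
    """Return (name, checkpoint_size, model_size) triples for entries whose size differs."""
    return [
        (name, len(value), expected_sizes[name])
        for name, value in state_dict.items()
        if name in expected_sizes and len(value) != expected_sizes[name]
    ]


def drop_mismatched(state_dict, mismatched_keys):
    """Return a new dict without the entries named in mismatched_keys."""
    # The commit builds a list of names; a set is used here only to make
    # membership checks constant-time. The result is the same.
    mismatched_names = {name for name, _, _ in mismatched_keys}
    return {k: v for k, v in state_dict.items() if k not in mismatched_names}


# A checkpoint whose classifier head has 2 entries, loaded into a model
# whose head expects 5 (for example, a different number of labels).
checkpoint = {"encoder.weight": [0.1, 0.2, 0.3], "classifier.weight": [0.5, 0.6]}
model_sizes = {"encoder.weight": 3, "classifier.weight": 5}

mismatched_keys = find_mismatched(checkpoint, model_sizes)
filtered = drop_mismatched(checkpoint, mismatched_keys)

print(mismatched_keys)   # [('classifier.weight', 2, 5)]
print(sorted(filtered))  # ['encoder.weight']
```

The filter builds a new dict rather than deleting keys in place, so the original checkpoint mapping is left untouched, which mirrors the dict comprehension in the diff.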
