We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 08d2dfb commit cc88a51Copy full SHA for cc88a51
src/transformers/modeling_utils.py
@@ -4907,7 +4907,9 @@ def _load_pretrained_model(
4907
model_to_load, state_dict, start_prefix
4908
)
4909
# at this point the state dict should be on cpu, we don't need to actually read it
4910
- fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(state_dict)
+ mismatched_names = [name for name, _, _ in mismatched_keys]
4911
+ fixed_state_dict = {k: v for k, v in state_dict.items() if k not in mismatched_names}
4912
+ fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(fixed_state_dict)
4913
model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
4914
else:
4915
# This should always be a list but, just to be sure.
0 commit comments