Skip to content

Commit f523b11

Browse files
Fix loading if unexpected keys are present (#3720)
* Fix loading * make style
1 parent 79fa94e commit f523b11

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import inspect
1818
import itertools
1919
import os
20+
import re
2021
from functools import partial
2122
from typing import Any, Callable, List, Optional, Tuple, Union
2223

@@ -162,6 +163,7 @@ class ModelMixin(torch.nn.Module):
162163
config_name = CONFIG_NAME
163164
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
164165
_supports_gradient_checkpointing = False
166+
_keys_to_ignore_on_load_unexpected = None
165167

166168
def __init__(self):
167169
super().__init__()
@@ -608,13 +610,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
608610
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
609611
" those weights or else make sure your checkpoint file is correct."
610612
)
613+
unexpected_keys = []
611614

612615
empty_state_dict = model.state_dict()
613616
for param_name, param in state_dict.items():
614617
accepts_dtype = "dtype" in set(
615618
inspect.signature(set_module_tensor_to_device).parameters.keys()
616619
)
617620

621+
if param_name not in empty_state_dict:
622+
unexpected_keys.append(param_name)
623+
continue
624+
618625
if empty_state_dict[param_name].shape != param.shape:
619626
raise ValueError(
620627
f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
@@ -626,6 +633,16 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
626633
)
627634
else:
628635
set_module_tensor_to_device(model, param_name, param_device, value=param)
636+
637+
if cls._keys_to_ignore_on_load_unexpected is not None:
638+
for pat in cls._keys_to_ignore_on_load_unexpected:
639+
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
640+
641+
if len(unexpected_keys) > 0:
642+
logger.warn(
643+
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
644+
)
645+
629646
else: # else let accelerate handle loading and dispatching.
630647
# Load weights and dispatch according to the device_map
631648
# by default the device_map is None and the weights are loaded on the CPU

src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
6161
dot-product/softmax to float() when training with mixed precision.
6262
"""
6363

64+
_keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"]
65+
6466
@register_to_config
6567
def __init__(
6668
self,

0 commit comments

Comments
 (0)