Fix loading if unexpected keys are present #3720

Merged

17 changes: 17 additions & 0 deletions src/diffusers/models/modeling_utils.py
@@ -17,6 +17,7 @@
import inspect
import itertools
import os
import re
from functools import partial
from typing import Any, Callable, List, Optional, Tuple, Union

@@ -162,6 +163,7 @@ class ModelMixin(torch.nn.Module):
    config_name = CONFIG_NAME
    _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
    _supports_gradient_checkpointing = False
    _keys_to_ignore_on_load_unexpected = None

    def __init__(self):
        super().__init__()
@@ -608,13 +610,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
" those weights or else make sure your checkpoint file is correct."
)
unexpected_keys = []

                empty_state_dict = model.state_dict()
                for param_name, param in state_dict.items():
                    accepts_dtype = "dtype" in set(
                        inspect.signature(set_module_tensor_to_device).parameters.keys()
                    )

                    if param_name not in empty_state_dict:
                        unexpected_keys.append(param_name)
                        continue

                    if empty_state_dict[param_name].shape != param.shape:
                        raise ValueError(
                            f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name].shape}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
@@ -626,6 +633,16 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                        )
                    else:
                        set_module_tensor_to_device(model, param_name, param_device, value=param)

                if cls._keys_to_ignore_on_load_unexpected is not None:
                    for pat in cls._keys_to_ignore_on_load_unexpected:
                        unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

                if len(unexpected_keys) > 0:
                    logger.warning(
                        f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {', '.join(unexpected_keys)}"
                    )

            else:  # else let accelerate handle loading and dispatching.
                # Load weights and dispatch according to the device_map
                # by default the device_map is None and the weights are loaded on the CPU
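The net effect of this change: when a checkpoint contains entries that the freshly initialized model has no parameter for, `from_pretrained` now records them and warns instead of failing mid-load, and any key matching a regex in the class-level `_keys_to_ignore_on_load_unexpected` list is filtered out before the warning is emitted. A minimal, self-contained sketch of that filtering logic (the key names and sets below are illustrative stand-ins, not the real state dicts):

import re

# Illustrative stand-ins for the model's and the checkpoint's state_dict keys.
model_keys = {"h.0.attn.c_attn.weight", "h.0.attn.c_proj.weight"}
checkpoint_keys = {
    "h.0.attn.c_attn.weight",
    "h.0.attn.c_proj.weight",
    "h.0.attn.bias",  # present in the checkpoint, absent from the model
    "h.0.attn.masked_bias",
}

# 1. Collect keys the model does not expect instead of raising.
unexpected_keys = [k for k in checkpoint_keys if k not in model_keys]

# 2. Drop keys matching any ignore pattern, as the new code does.
keys_to_ignore = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"]
for pat in keys_to_ignore:
    unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

# 3. Warn only about whatever is left (empty here, so no warning).
if unexpected_keys:
    print(f"Some weights were not used: {', '.join(unexpected_keys)}")
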
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py
@@ -61,6 +61,8 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
            dot-product/softmax to float() when training with mixed precision.
    """

    _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"]

    @register_to_config
    def __init__(
        self,
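For reference, these two patterns are meant to cover the GPT-2-style attention buffers (e.g. h.0.attn.bias, h.11.attn.masked_bias) that checkpoints may carry but that the model re-creates itself rather than loading. A quick sanity check of the patterns (assumed usage, not part of the PR):

import re

patterns = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"]

for key in ["h.0.attn.bias", "h.11.attn.masked_bias", "h.0.attn.c_attn.weight"]:
    ignored = any(re.search(p, key) for p in patterns)
    print(f"{key}: {'ignored' if ignored else 'kept'}")
# h.0.attn.bias: ignored
# h.11.attn.masked_bias: ignored
# h.0.attn.c_attn.weight: kept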