diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index c9fabf93253b..f6d6bc5711cd 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -17,6 +17,7 @@ import inspect import itertools import os +import re from functools import partial from typing import Any, Callable, List, Optional, Tuple, Union @@ -162,6 +163,7 @@ class ModelMixin(torch.nn.Module): config_name = CONFIG_NAME _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] _supports_gradient_checkpointing = False + _keys_to_ignore_on_load_unexpected = None def __init__(self): super().__init__() @@ -608,6 +610,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize" " those weights or else make sure your checkpoint file is correct." ) + unexpected_keys = [] empty_state_dict = model.state_dict() for param_name, param in state_dict.items(): @@ -615,6 +618,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P inspect.signature(set_module_tensor_to_device).parameters.keys() ) + if param_name not in empty_state_dict: + unexpected_keys.append(param_name) + continue + if empty_state_dict[param_name].shape != param.shape: raise ValueError( f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example." 
@@ -626,6 +633,16 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P ) else: set_module_tensor_to_device(model, param_name, param_device, value=param) + + if cls._keys_to_ignore_on_load_unexpected is not None: + for pat in cls._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" + ) + else: # else let accelerate handle loading and dispatching. # Load weights and dispatch according to the device_map # by default the device_map is None and the weights are loaded on the CPU diff --git a/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py b/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py index febc8e09e6ab..9b962f6e0656 100644 --- a/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py +++ b/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py @@ -61,6 +61,8 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin): dot-product/softmax to float() when training with mixed precision. """ + _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"] + @register_to_config def __init__( self,