Skip to content

Commit 703307e

Browse files
Fix config deprecation (#3129)
* Better deprecation message * Better deprecation message * Better doc string * Fixes * fix more * fix more * Improve __getattr__ * correct more * fix more * fix * Improve more * more improvements * fix more * Apply suggestions from code review Co-authored-by: Pedro Cuenca <[email protected]> * make style * Fix all rest & add tests & remove old deprecation fns --------- Co-authored-by: Pedro Cuenca <[email protected]>
1 parent ed8fd38 commit 703307e

22 files changed

+209
-146
lines changed

examples/community/unclip_image_interpolation.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -372,9 +372,9 @@ def __call__(
372372
self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
373373
decoder_timesteps_tensor = self.decoder_scheduler.timesteps
374374

375-
num_channels_latents = self.decoder.in_channels
376-
height = self.decoder.sample_size
377-
width = self.decoder.sample_size
375+
num_channels_latents = self.decoder.config.in_channels
376+
height = self.decoder.config.sample_size
377+
width = self.decoder.config.sample_size
378378

379379
decoder_latents = self.prepare_latents(
380380
(batch_size, num_channels_latents, height, width),
@@ -425,9 +425,9 @@ def __call__(
425425
self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
426426
super_res_timesteps_tensor = self.super_res_scheduler.timesteps
427427

428-
channels = self.super_res_first.in_channels // 2
429-
height = self.super_res_first.sample_size
430-
width = self.super_res_first.sample_size
428+
channels = self.super_res_first.config.in_channels // 2
429+
height = self.super_res_first.config.sample_size
430+
width = self.super_res_first.config.sample_size
431431

432432
super_res_latents = self.prepare_latents(
433433
(batch_size, channels, height, width),

examples/community/unclip_text_interpolation.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -452,9 +452,9 @@ def __call__(
452452
self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
453453
decoder_timesteps_tensor = self.decoder_scheduler.timesteps
454454

455-
num_channels_latents = self.decoder.in_channels
456-
height = self.decoder.sample_size
457-
width = self.decoder.sample_size
455+
num_channels_latents = self.decoder.config.in_channels
456+
height = self.decoder.config.sample_size
457+
width = self.decoder.config.sample_size
458458

459459
decoder_latents = self.prepare_latents(
460460
(batch_size, num_channels_latents, height, width),
@@ -505,9 +505,9 @@ def __call__(
505505
self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
506506
super_res_timesteps_tensor = self.super_res_scheduler.timesteps
507507

508-
channels = self.super_res_first.in_channels // 2
509-
height = self.super_res_first.sample_size
510-
width = self.super_res_first.sample_size
508+
channels = self.super_res_first.config.in_channels // 2
509+
height = self.super_res_first.config.sample_size
510+
width = self.super_res_first.config.sample_size
511511

512512
super_res_latents = self.prepare_latents(
513513
(batch_size, channels, height, width),

src/diffusers/configuration_utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,24 @@ def register_to_config(self, **kwargs):
118118

119119
self._internal_dict = FrozenDict(internal_dict)
120120

121+
def __getattr__(self, name: str) -> Any:
122+
"""The only reason we overwrite `getattr` here is to gracefully deprecate accessing
123+
config attributes directly. See https://github.com/huggingface/diffusers/pull/3129
124+
125+
Tihs funtion is mostly copied from PyTorch's __getattr__ overwrite:
126+
https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
127+
"""
128+
129+
is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
130+
is_attribute = name in self.__dict__
131+
132+
if is_in_config and not is_attribute:
133+
deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'."
134+
deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
135+
return self._internal_dict[name]
136+
137+
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
138+
121139
def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
122140
"""
123141
Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the

src/diffusers/models/autoencoder_kl.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import torch.nn as nn
1919

2020
from ..configuration_utils import ConfigMixin, register_to_config
21-
from ..utils import BaseOutput, apply_forward_hook, deprecate
21+
from ..utils import BaseOutput, apply_forward_hook
2222
from .modeling_utils import ModelMixin
2323
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
2424

@@ -123,16 +123,6 @@ def __init__(
123123
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
124124
self.tile_overlap_factor = 0.25
125125

126-
@property
127-
def block_out_channels(self):
128-
deprecate(
129-
"block_out_channels",
130-
"1.0.0",
131-
"Accessing `block_out_channels` directly via vae.block_out_channels is deprecated. Please use `vae.config.block_out_channels instead`",
132-
standard_warn=False,
133-
)
134-
return self.config.block_out_channels
135-
136126
def _set_gradient_checkpointing(self, module, value=False):
137127
if isinstance(module, (Encoder, Decoder)):
138128
module.gradient_checkpointing = value

src/diffusers/models/modeling_utils.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import inspect
1818
import os
1919
from functools import partial
20-
from typing import Callable, List, Optional, Tuple, Union
20+
from typing import Any, Callable, List, Optional, Tuple, Union
2121

2222
import torch
2323
from torch import Tensor, device
@@ -32,6 +32,7 @@
3232
WEIGHTS_NAME,
3333
_add_variant,
3434
_get_model_file,
35+
deprecate,
3536
is_accelerate_available,
3637
is_safetensors_available,
3738
is_torch_version,
@@ -156,6 +157,24 @@ class ModelMixin(torch.nn.Module):
156157
def __init__(self):
157158
super().__init__()
158159

160+
def __getattr__(self, name: str) -> Any:
161+
"""The only reason we overwrite `getattr` here is to gracefully deprecate accessing
162+
config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
163+
__getattr__ here in addition so that we don't trigger `torch.nn.Module`'s __getattr__':
164+
https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
165+
"""
166+
167+
is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
168+
is_attribute = name in self.__dict__
169+
170+
if is_in_config and not is_attribute:
171+
deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'."
172+
deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3)
173+
return self._internal_dict[name]
174+
175+
# call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
176+
return super().__getattr__(name)
177+
159178
@property
160179
def is_gradient_checkpointing(self) -> bool:
161180
"""

src/diffusers/models/unet_1d.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import torch.nn as nn
2020

2121
from ..configuration_utils import ConfigMixin, register_to_config
22-
from ..utils import BaseOutput, deprecate
22+
from ..utils import BaseOutput
2323
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
2424
from .modeling_utils import ModelMixin
2525
from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
@@ -190,16 +190,6 @@ def __init__(
190190
fc_dim=block_out_channels[-1] // 4,
191191
)
192192

193-
@property
194-
def in_channels(self):
195-
deprecate(
196-
"in_channels",
197-
"1.0.0",
198-
"Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead",
199-
standard_warn=False,
200-
)
201-
return self.config.in_channels
202-
203193
def forward(
204194
self,
205195
sample: torch.FloatTensor,

src/diffusers/models/unet_2d.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import torch.nn as nn
1919

2020
from ..configuration_utils import ConfigMixin, register_to_config
21-
from ..utils import BaseOutput, deprecate
21+
from ..utils import BaseOutput
2222
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
2323
from .modeling_utils import ModelMixin
2424
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
@@ -216,16 +216,6 @@ def __init__(
216216
self.conv_act = nn.SiLU()
217217
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
218218

219-
@property
220-
def in_channels(self):
221-
deprecate(
222-
"in_channels",
223-
"1.0.0",
224-
"Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead",
225-
standard_warn=False,
226-
)
227-
return self.config.in_channels
228-
229219
def forward(
230220
self,
231221
sample: torch.FloatTensor,

src/diffusers/models/unet_2d_condition.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from ..configuration_utils import ConfigMixin, register_to_config
2323
from ..loaders import UNet2DConditionLoadersMixin
24-
from ..utils import BaseOutput, deprecate, logging
24+
from ..utils import BaseOutput, logging
2525
from .attention_processor import AttentionProcessor, AttnProcessor
2626
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
2727
from .modeling_utils import ModelMixin
@@ -447,16 +447,6 @@ def __init__(
447447
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
448448
)
449449

450-
@property
451-
def in_channels(self):
452-
deprecate(
453-
"in_channels",
454-
"1.0.0",
455-
"Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead",
456-
standard_warn=False,
457-
)
458-
return self.config.in_channels
459-
460450
@property
461451
def attn_processors(self) -> Dict[str, AttentionProcessor]:
462452
r"""

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 33 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,7 @@ def register_modules(self, **kwargs):
508508
setattr(self, name, module)
509509

510510
def __setattr__(self, name: str, value: Any):
511-
if hasattr(self, name) and hasattr(self.config, name):
511+
if name in self.__dict__ and hasattr(self.config, name):
512512
# We need to overwrite the config if name exists in config
513513
if isinstance(getattr(self.config, name), (tuple, list)):
514514
if value is not None and self.config[name][0] is not None:
@@ -648,26 +648,25 @@ def module_is_offloaded(module):
648648
)
649649

650650
module_names, _ = self._get_signature_keys(self)
651-
module_names = [m for m in module_names if hasattr(self, m)]
651+
modules = [getattr(self, n, None) for n in module_names]
652+
modules = [m for m in modules if isinstance(m, torch.nn.Module)]
652653

653654
is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
654-
for name in module_names:
655-
module = getattr(self, name)
656-
if isinstance(module, torch.nn.Module):
657-
module.to(torch_device, torch_dtype)
658-
if (
659-
module.dtype == torch.float16
660-
and str(torch_device) in ["cpu"]
661-
and not silence_dtype_warnings
662-
and not is_offloaded
663-
):
664-
logger.warning(
665-
"Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
666-
" is not recommended to move them to `cpu` as running them will fail. Please make"
667-
" sure to use an accelerator to run the pipeline in inference, due to the lack of"
668-
" support for`float16` operations on this device in PyTorch. Please, remove the"
669-
" `torch_dtype=torch.float16` argument, or use another device for inference."
670-
)
655+
for module in modules:
656+
module.to(torch_device, torch_dtype)
657+
if (
658+
module.dtype == torch.float16
659+
and str(torch_device) in ["cpu"]
660+
and not silence_dtype_warnings
661+
and not is_offloaded
662+
):
663+
logger.warning(
664+
"Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
665+
" is not recommended to move them to `cpu` as running them will fail. Please make"
666+
" sure to use an accelerator to run the pipeline in inference, due to the lack of"
667+
" support for`float16` operations on this device in PyTorch. Please, remove the"
668+
" `torch_dtype=torch.float16` argument, or use another device for inference."
669+
)
671670
return self
672671

673672
@property
@@ -677,12 +676,12 @@ def device(self) -> torch.device:
677676
`torch.device`: The torch device on which the pipeline is located.
678677
"""
679678
module_names, _ = self._get_signature_keys(self)
680-
module_names = [m for m in module_names if hasattr(self, m)]
679+
modules = [getattr(self, n, None) for n in module_names]
680+
modules = [m for m in modules if isinstance(m, torch.nn.Module)]
681+
682+
for module in modules:
683+
return module.device
681684

682-
for name in module_names:
683-
module = getattr(self, name)
684-
if isinstance(module, torch.nn.Module):
685-
return module.device
686685
return torch.device("cpu")
687686

688687
@classmethod
@@ -1451,13 +1450,12 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
14511450
for child in module.children():
14521451
fn_recursive_set_mem_eff(child)
14531452

1454-
module_names, _, _ = self.extract_init_dict(dict(self.config))
1455-
module_names = [m for m in module_names if hasattr(self, m)]
1453+
module_names, _ = self._get_signature_keys(self)
1454+
modules = [getattr(self, n, None) for n in module_names]
1455+
modules = [m for m in modules if isinstance(m, torch.nn.Module)]
14561456

1457-
for module_name in module_names:
1458-
module = getattr(self, module_name)
1459-
if isinstance(module, torch.nn.Module):
1460-
fn_recursive_set_mem_eff(module)
1457+
for module in modules:
1458+
fn_recursive_set_mem_eff(module)
14611459

14621460
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
14631461
r"""
@@ -1484,10 +1482,9 @@ def disable_attention_slicing(self):
14841482
self.enable_attention_slicing(None)
14851483

14861484
def set_attention_slice(self, slice_size: Optional[int]):
1487-
module_names, _, _ = self.extract_init_dict(dict(self.config))
1488-
module_names = [m for m in module_names if hasattr(self, m)]
1485+
module_names, _ = self._get_signature_keys(self)
1486+
modules = [getattr(self, n, None) for n in module_names]
1487+
modules = [m for m in modules if isinstance(m, torch.nn.Module) and hasattr(m, "set_attention_slice")]
14891488

1490-
for module_name in module_names:
1491-
module = getattr(self, module_name)
1492-
if isinstance(module, torch.nn.Module) and hasattr(module, "set_attention_slice"):
1493-
module.set_attention_slice(slice_size)
1489+
for module in modules:
1490+
module.set_attention_slice(slice_size)

src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,7 @@ def __call__(
441441
timesteps = self.scheduler.timesteps
442442

443443
# Prepare latent variables
444-
num_channels_latents = self.unet.in_channels
444+
num_channels_latents = self.unet.config.in_channels
445445
latents = self.prepare_latents(
446446
batch_size * num_videos_per_prompt,
447447
num_channels_latents,

src/diffusers/pipelines/unclip/pipeline_unclip.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -413,9 +413,9 @@ def __call__(
413413
self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
414414
decoder_timesteps_tensor = self.decoder_scheduler.timesteps
415415

416-
num_channels_latents = self.decoder.in_channels
417-
height = self.decoder.sample_size
418-
width = self.decoder.sample_size
416+
num_channels_latents = self.decoder.config.in_channels
417+
height = self.decoder.config.sample_size
418+
width = self.decoder.config.sample_size
419419

420420
decoder_latents = self.prepare_latents(
421421
(batch_size, num_channels_latents, height, width),
@@ -466,9 +466,9 @@ def __call__(
466466
self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
467467
super_res_timesteps_tensor = self.super_res_scheduler.timesteps
468468

469-
channels = self.super_res_first.in_channels // 2
470-
height = self.super_res_first.sample_size
471-
width = self.super_res_first.sample_size
469+
channels = self.super_res_first.config.in_channels // 2
470+
height = self.super_res_first.config.sample_size
471+
width = self.super_res_first.config.sample_size
472472

473473
super_res_latents = self.prepare_latents(
474474
(batch_size, channels, height, width),

src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -339,9 +339,9 @@ def __call__(
339339
self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
340340
decoder_timesteps_tensor = self.decoder_scheduler.timesteps
341341

342-
num_channels_latents = self.decoder.in_channels
343-
height = self.decoder.sample_size
344-
width = self.decoder.sample_size
342+
num_channels_latents = self.decoder.config.in_channels
343+
height = self.decoder.config.sample_size
344+
width = self.decoder.config.sample_size
345345

346346
if decoder_latents is None:
347347
decoder_latents = self.prepare_latents(
@@ -393,9 +393,9 @@ def __call__(
393393
self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
394394
super_res_timesteps_tensor = self.super_res_scheduler.timesteps
395395

396-
channels = self.super_res_first.in_channels // 2
397-
height = self.super_res_first.sample_size
398-
width = self.super_res_first.sample_size
396+
channels = self.super_res_first.config.in_channels // 2
397+
height = self.super_res_first.config.sample_size
398+
width = self.super_res_first.config.sample_size
399399

400400
if super_res_latents is None:
401401
super_res_latents = self.prepare_latents(

0 commit comments

Comments
 (0)