
Commit 72a9489

deprecate AttentionBlock
1 parent a39d42b commit 72a9489

17 files changed: +459 -128 lines changed

examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py

Lines changed: 1 addition & 0 deletions
@@ -350,6 +350,7 @@ def main(args):
             "UpBlock2D",
             "UpBlock2D",
         ),
+        attention_block_type="Attention",
     )
 
     # Create EMA for the model.
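
For reference, the training script now builds its UNet roughly as sketched below. This is illustrative only: the `attention_block_type="Attention"` keyword is assumed to be plumbed through `UNet2DModel` by the other files in this commit, and the block layout and channel widths shown here are placeholders, not a prescribed configuration.

# Minimal sketch, not the full script: construct the UNet with the
# non-deprecated attention block. `attention_block_type` is assumed to be
# accepted by UNet2DModel via changes elsewhere in this commit.
from diffusers import UNet2DModel

model = UNet2DModel(
    sample_size=64,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(128, 128, 256, 256, 512, 512),
    down_block_types=(
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "AttnUpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
    attention_block_type="Attention",  # new: use attention_processor.Attention instead of AttentionBlock
)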

examples/unconditional_image_generation/train_unconditional.py

Lines changed: 2 additions & 0 deletions
@@ -397,10 +397,12 @@ def load_model_hook(models, input_dir):
                 "UpBlock2D",
                 "UpBlock2D",
             ),
+            attention_block_type="Attention",
         )
     else:
         config = UNet2DModel.load_config(args.model_config_name_or_path)
         model = UNet2DModel.from_config(config)
+        model._convert_deprecated_attention_blocks()
 
     # Create EMA for the model.
     if args.use_ema:
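
The same conversion can be invoked directly when a model is instantiated from an older config, as in this rough sketch; the config path is a hypothetical placeholder and `_convert_deprecated_attention_blocks` is the private `ModelMixin` helper that the `AttentionBlock` docstring in this commit points to:

# Rough sketch of the manual conversion path used in the script above.
from diffusers import UNet2DModel

config = UNet2DModel.load_config("path/to/legacy_unet_config")  # hypothetical path
model = UNet2DModel.from_config(config)

# Legacy configs may still build the deprecated AttentionBlock; swap those
# modules in place for diffusers.models.attention_processor.Attention.
model._convert_deprecated_attention_blocks()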

src/diffusers/models/attention.py

Lines changed: 53 additions & 102 deletions
@@ -11,27 +11,23 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import math
 from typing import Callable, Optional
 
 import torch
 import torch.nn.functional as F
 from torch import nn
 
-from ..utils.import_utils import is_xformers_available
-from .attention_processor import Attention
+from ..utils import deprecate
+from .attention_processor import Attention, SpatialAttnProcessor
 from .embeddings import CombinedTimestepLabelEmbeddings
 
 
-if is_xformers_available():
-    import xformers
-    import xformers.ops
-else:
-    xformers = None
-
-
 class AttentionBlock(nn.Module):
     """
+    This class is deprecated. Its forward method will throw an error. On model load, we convert all instances of
+    `AttentionBlock` to `diffusers.models.attention_processor.Attention`, see
+    `ModelMixin#_convert_deprecated_attention_blocks`.
+
     An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
     to the N-d case.
     https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
@@ -46,8 +42,6 @@ class AttentionBlock(nn.Module):
         eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
     """
 
-    # IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore
-
     def __init__(
         self,
         channels: int,
@@ -57,6 +51,16 @@ def __init__(
         eps: float = 1e-5,
     ):
         super().__init__()
+
+        deprecation_message = (
+            "`AttentionBlock` has been deprecated and will be replaced with `diffusers.models.attention_processor.Attention`."
+            " The DiffusionPipeline loading this block in is auto converting it to `diffusers.models.attention_processor.Attention`."
+            " Please call `DiffusionPipeline#save_pretrained` and re-upload the pipeline to the hub."
+            " If you are only loading a model instead of a whole pipeline, the same instructions apply with `ModelMixin#save_pretrained`."
+        )
+
+        deprecate("AttentionBlock", "0.18.0", deprecation_message, standard_warn=True)
+
         self.channels = channels
 
         self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
@@ -71,107 +75,54 @@ def __init__(
         self.rescale_output_factor = rescale_output_factor
         self.proj_attn = nn.Linear(channels, channels, bias=True)
 
-        self._use_memory_efficient_attention_xformers = False
-        self._attention_op = None
-
-    def reshape_heads_to_batch_dim(self, tensor):
-        batch_size, seq_len, dim = tensor.shape
-        head_size = self.num_heads
-        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
-        return tensor
-
-    def reshape_batch_dim_to_heads(self, tensor):
-        batch_size, seq_len, dim = tensor.shape
-        head_size = self.num_heads
-        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
-        return tensor
-
     def set_use_memory_efficient_attention_xformers(
         self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
     ):
-        if use_memory_efficient_attention_xformers:
-            if not is_xformers_available():
-                raise ModuleNotFoundError(
-                    (
-                        "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
-                        " xformers"
-                    ),
-                    name="xformers",
-                )
-            elif not torch.cuda.is_available():
-                raise ValueError(
-                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
-                    " only available for GPU "
-                )
-            else:
-                try:
-                    # Make sure we can run the memory efficient attention
-                    _ = xformers.ops.memory_efficient_attention(
-                        torch.randn((1, 2, 40), device="cuda"),
-                        torch.randn((1, 2, 40), device="cuda"),
-                        torch.randn((1, 2, 40), device="cuda"),
-                    )
-                except Exception as e:
-                    raise e
-        self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
-        self._attention_op = attention_op
+        raise ValueError(
+            "`AttentionBlock` should have been converted after load to `diffusers.models.attention_processor.Attention`"
+        )
 
     def forward(self, hidden_states):
-        residual = hidden_states
-        batch, channel, height, width = hidden_states.shape
-
-        # norm
-        hidden_states = self.group_norm(hidden_states)
-
-        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
-
-        # proj to q, k, v
-        query_proj = self.query(hidden_states)
-        key_proj = self.key(hidden_states)
-        value_proj = self.value(hidden_states)
-
-        scale = 1 / math.sqrt(self.channels / self.num_heads)
-
-        query_proj = self.reshape_heads_to_batch_dim(query_proj)
-        key_proj = self.reshape_heads_to_batch_dim(key_proj)
-        value_proj = self.reshape_heads_to_batch_dim(value_proj)
+        raise ValueError(
+            "`AttentionBlock` should have been converted after load to `diffusers.models.attention_processor.Attention`"
        )
 
-        if self._use_memory_efficient_attention_xformers:
-            # Memory efficient attention
-            hidden_states = xformers.ops.memory_efficient_attention(
-                query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op
-            )
-            hidden_states = hidden_states.to(query_proj.dtype)
+    def _as_attention_processor_attention(self):
+        if self.num_head_size is None:
+            # When `self.num_head_size` is None, there is a single attention head
+            # of all the channels
+            dim_head = self.channels
         else:
-            attention_scores = torch.baddbmm(
-                torch.empty(
-                    query_proj.shape[0],
-                    query_proj.shape[1],
-                    key_proj.shape[1],
-                    dtype=query_proj.dtype,
-                    device=query_proj.device,
-                ),
-                query_proj,
-                key_proj.transpose(-1, -2),
-                beta=0,
-                alpha=scale,
-            )
-            attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
-            hidden_states = torch.bmm(attention_probs, value_proj)
+            dim_head = self.num_head_size
+
+        # This will allocate some additional memory but as this is only done once during model load,
+        # it should be ok.
+        attn = Attention(
+            self.channels,
+            heads=self.num_heads,
+            dim_head=dim_head,
+            bias=True,
+            upcast_softmax=True,
+            norm_num_groups=self.group_norm.num_groups,
+            processor=SpatialAttnProcessor(),
+            eps=self.group_norm.eps,
+            rescale_output_factor=self.rescale_output_factor,
+        )
 
-        # reshape hidden_states
-        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+        param = next(self.parameters())
 
-        # compute next hidden_states
-        hidden_states = self.proj_attn(hidden_states)
+        device = param.device
+        dtype = param.dtype
 
-        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
+        attn.to(device=device, dtype=dtype)
 
-        # res connect and rescale
-        hidden_states = (hidden_states + residual) / self.rescale_output_factor
-        return hidden_states
+        attn.group_norm.load_state_dict(self.group_norm.state_dict())
+        attn.to_q.load_state_dict(self.query.state_dict())
+        attn.to_k.load_state_dict(self.key.state_dict())
+        attn.to_v.load_state_dict(self.value.state_dict())
+        attn.to_out[0].load_state_dict(self.proj_attn.state_dict())
+
+        return attn
 
 
 class BasicTransformerBlock(nn.Module):
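
Per the deprecation message above, checkpoints that still serialize `AttentionBlock` keep working because loading converts them; re-saving then makes the warning go away for good. A minimal sketch of that workflow, with hypothetical paths:

# Sketch of the re-save flow the deprecation message asks for; the paths are placeholders.
from diffusers import DiffusionPipeline

# Loading auto-converts every deprecated AttentionBlock into
# diffusers.models.attention_processor.Attention and emits the warning once.
pipeline = DiffusionPipeline.from_pretrained("path/to/old-pipeline")

# Saving persists the converted modules, so later loads no longer warn.
pipeline.save_pretrained("path/to/converted-pipeline")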

src/diffusers/models/attention_processor.py

Lines changed: 102 additions & 1 deletion
@@ -62,13 +62,16 @@ def __init__(
         out_bias: bool = True,
         scale_qk: bool = True,
         processor: Optional["AttnProcessor"] = None,
+        eps: float = 1e-5,
+        rescale_output_factor: float = 1.0,
     ):
         super().__init__()
         inner_dim = dim_head * heads
         cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
         self.upcast_attention = upcast_attention
         self.upcast_softmax = upcast_softmax
         self.cross_attention_norm = cross_attention_norm
+        self.rescale_output_factor = rescale_output_factor
 
         self.scale = dim_head**-0.5 if scale_qk else 1.0
 
@@ -81,7 +84,7 @@ def __init__(
         self.added_kv_proj_dim = added_kv_proj_dim
 
         if norm_num_groups is not None:
-            self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
+            self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=eps, affine=True)
         else:
             self.group_norm = None
 
@@ -117,6 +120,10 @@ def set_use_memory_efficient_attention_xformers(
             self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor)
         )
 
+        is_spatial_attention = hasattr(self, "processor") and isinstance(
+            self.processor, (SpatialAttnProcessor, XFormersSpatialAttnProcessor)
+        )
+
         if use_memory_efficient_attention_xformers:
             if self.added_kv_proj_dim is not None:
                 # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
@@ -159,6 +166,8 @@ def set_use_memory_efficient_attention_xformers(
                 )
                 processor.load_state_dict(self.processor.state_dict())
                 processor.to(self.processor.to_q_lora.up.weight.device)
+            elif is_spatial_attention:
+                processor = XFormersSpatialAttnProcessor()
             else:
                 processor = XFormersAttnProcessor(attention_op=attention_op)
         else:
@@ -170,6 +179,8 @@ def set_use_memory_efficient_attention_xformers(
                 )
                 processor.load_state_dict(self.processor.state_dict())
                 processor.to(self.processor.to_q_lora.up.weight.device)
+            elif is_spatial_attention:
+                processor = SpatialAttnProcessor()
            else:
                 processor = AttnProcessor()
 
@@ -684,6 +695,94 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,
         return hidden_states
 
 
+class SpatialAttnProcessor:
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        if attention_mask is not None:
+            raise ValueError(f"{self.__class__.__name__} does not support `attention_mask`")
+
+        if encoder_hidden_states is not None:
+            raise ValueError(f"{self.__class__.__name__} does not support `encoder_hidden_states`")
+
+        residual = hidden_states
+        batch, channel, height, width = hidden_states.shape
+
+        # norm
+        hidden_states = attn.group_norm(hidden_states)
+
+        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
+
+        # proj to q, k, v
+        query_proj = attn.to_q(hidden_states)
+        key_proj = attn.to_k(hidden_states)
+        value_proj = attn.to_v(hidden_states)
+
+        query_proj = attn.head_to_batch_dim(query_proj)
+        key_proj = attn.head_to_batch_dim(key_proj)
+        value_proj = attn.head_to_batch_dim(value_proj)
+
+        attention_probs = attn.get_attention_scores(query_proj, key_proj)
+        hidden_states = torch.bmm(attention_probs, value_proj)
+
+        # reshape hidden_states
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # compute next hidden_states
+        hidden_states = attn.to_out[0](hidden_states)
+
+        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
+
+        # res connect and rescale
+        hidden_states = (hidden_states + residual) / attn.rescale_output_factor
+        return hidden_states
+
+
+class XFormersSpatialAttnProcessor:
+    def __init__(self, attention_op: Optional[Callable] = None):
+        self.attention_op = attention_op
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        if attention_mask is not None:
+            raise ValueError(f"{self.__class__.__name__} does not support `attention_mask`")
+
+        if encoder_hidden_states is not None:
+            raise ValueError(f"{self.__class__.__name__} does not support `encoder_hidden_states`")
+
+        residual = hidden_states
+        batch, channel, height, width = hidden_states.shape
+
+        # norm
+        hidden_states = attn.group_norm(hidden_states)
+
+        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
+
+        # proj to q, k, v
+        query_proj = attn.to_q(hidden_states)
+        key_proj = attn.to_k(hidden_states)
+        value_proj = attn.to_v(hidden_states)
+
+        query_proj = attn.head_to_batch_dim(query_proj)
+        key_proj = attn.head_to_batch_dim(key_proj)
+        value_proj = attn.head_to_batch_dim(value_proj)
+
+        # Memory efficient attention
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query_proj, key_proj, value_proj, attn_bias=None, op=self.attention_op, scale=attn.scale
+        )
+        hidden_states = hidden_states.to(query_proj.dtype)
+
+        # reshape hidden_states
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # compute next hidden_states
+        hidden_states = attn.to_out[0](hidden_states)
+
+        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
+
+        # res connect and rescale
+        hidden_states = (hidden_states + residual) / attn.rescale_output_factor
+        return hidden_states
+
+
 AttentionProcessor = Union[
     AttnProcessor,
     XFormersAttnProcessor,
@@ -692,4 +791,6 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,
     SlicedAttnAddedKVProcessor,
     LoRAAttnProcessor,
     LoRAXFormersAttnProcessor,
+    SpatialAttnProcessor,
+    XFormersSpatialAttnProcessor,
 ]
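
A hedged usage sketch of the new processor, mirroring the arguments `AttentionBlock._as_attention_processor_attention` passes in attention.py; the channel count, group count, and spatial size below are illustrative only:

# Illustrative only: wire SpatialAttnProcessor into an Attention module the
# same way _as_attention_processor_attention does, then run it on an NCHW map.
import torch
from diffusers.models.attention_processor import Attention, SpatialAttnProcessor

channels = 64  # hypothetical feature-map width
attn = Attention(
    channels,
    heads=1,                    # single head over all channels, as when num_head_channels is None
    dim_head=channels,
    bias=True,
    upcast_softmax=True,
    norm_num_groups=32,
    eps=1e-5,                   # new kwarg added by this commit
    rescale_output_factor=1.0,  # new kwarg added by this commit
    processor=SpatialAttnProcessor(),
)

sample = torch.randn(2, channels, 8, 8)  # (batch, channel, height, width)
out = attn(sample)                       # spatial self-attention plus residual
assert out.shape == sample.shape

Calling `attn.set_use_memory_efficient_attention_xformers(True)` on such a module should swap in `XFormersSpatialAttnProcessor` via the branch added above, which requires xformers and a CUDA device.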
