@@ -11,27 +11,23 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import math
 from typing import Callable, Optional

 import torch
 import torch.nn.functional as F
 from torch import nn

-from ..utils.import_utils import is_xformers_available
-from .attention_processor import Attention
+from ..utils import deprecate
+from .attention_processor import Attention, SpatialAttnProcessor
 from .embeddings import CombinedTimestepLabelEmbeddings


-if is_xformers_available():
-    import xformers
-    import xformers.ops
-else:
-    xformers = None
-
-
 class AttentionBlock(nn.Module):
     """
+    This class is deprecated. Its forward method will throw an error. On model load, we convert all instances of
+    `AttentionBlock` to `diffusers.models.attention_processor.Attention`, see
+    `ModelMixin#_convert_deprecated_attention_blocks`.
+
     An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
     to the N-d case.
     https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
@@ -46,8 +42,6 @@ class AttentionBlock(nn.Module):
         eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
     """

-    # IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore
-
     def __init__(
         self,
         channels: int,
@@ -57,6 +51,16 @@ def __init__(
         eps: float = 1e-5,
     ):
         super().__init__()
+
+        deprecation_message = (
+            "`AttentionBlock` has been deprecated and will be replaced with `diffusers.models.attention_processor.Attention`."
+            " The DiffusionPipeline loading this block is automatically converting it to `diffusers.models.attention_processor.Attention`."
+            " Please call `DiffusionPipeline#save_pretrained` and re-upload the pipeline to the hub."
+            " If you are only loading a model instead of a whole pipeline, the same instructions apply with `ModelMixin#save_pretrained`."
+        )
+
+        deprecate("AttentionBlock", "0.18.0", deprecation_message, standard_warn=True)
+
         self.channels = channels

         self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
@@ -71,107 +75,54 @@ def __init__(
         self.rescale_output_factor = rescale_output_factor
         self.proj_attn = nn.Linear(channels, channels, bias=True)

-        self._use_memory_efficient_attention_xformers = False
-        self._attention_op = None
-
-    def reshape_heads_to_batch_dim(self, tensor):
-        batch_size, seq_len, dim = tensor.shape
-        head_size = self.num_heads
-        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
-        return tensor
-
-    def reshape_batch_dim_to_heads(self, tensor):
-        batch_size, seq_len, dim = tensor.shape
-        head_size = self.num_heads
-        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
-        return tensor
-
     def set_use_memory_efficient_attention_xformers(
         self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
     ):
-        if use_memory_efficient_attention_xformers:
-            if not is_xformers_available():
-                raise ModuleNotFoundError(
-                    (
-                        "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
-                        " xformers"
-                    ),
-                    name="xformers",
-                )
-            elif not torch.cuda.is_available():
-                raise ValueError(
-                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
-                    " only available for GPU "
-                )
-            else:
-                try:
-                    # Make sure we can run the memory efficient attention
-                    _ = xformers.ops.memory_efficient_attention(
-                        torch.randn((1, 2, 40), device="cuda"),
-                        torch.randn((1, 2, 40), device="cuda"),
-                        torch.randn((1, 2, 40), device="cuda"),
-                    )
-                except Exception as e:
-                    raise e
-        self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
-        self._attention_op = attention_op
+        raise ValueError(
+            "`AttentionBlock` should have been converted after load to `diffusers.models.attention_processor.Attention`"
+        )

     def forward(self, hidden_states):
-        residual = hidden_states
-        batch, channel, height, width = hidden_states.shape
-
-        # norm
-        hidden_states = self.group_norm(hidden_states)
-
-        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
-
-        # proj to q, k, v
-        query_proj = self.query(hidden_states)
-        key_proj = self.key(hidden_states)
-        value_proj = self.value(hidden_states)
-
-        scale = 1 / math.sqrt(self.channels / self.num_heads)
-
-        query_proj = self.reshape_heads_to_batch_dim(query_proj)
-        key_proj = self.reshape_heads_to_batch_dim(key_proj)
-        value_proj = self.reshape_heads_to_batch_dim(value_proj)
+        raise ValueError(
+            "`AttentionBlock` should have been converted after load to `diffusers.models.attention_processor.Attention`"
+        )

-        if self._use_memory_efficient_attention_xformers:
-            # Memory efficient attention
-            hidden_states = xformers.ops.memory_efficient_attention(
-                query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op
-            )
-            hidden_states = hidden_states.to(query_proj.dtype)
+    def _as_attention_processor_attention(self):
+        if self.num_head_size is None:
+            # When `self.num_head_size` is None, there is a single attention head
+            # of all the channels
+            dim_head = self.channels
         else:
-            attention_scores = torch.baddbmm(
-                torch.empty(
-                    query_proj.shape[0],
-                    query_proj.shape[1],
-                    key_proj.shape[1],
-                    dtype=query_proj.dtype,
-                    device=query_proj.device,
-                ),
-                query_proj,
-                key_proj.transpose(-1, -2),
-                beta=0,
-                alpha=scale,
-            )
-            attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
-            hidden_states = torch.bmm(attention_probs, value_proj)
+            dim_head = self.num_head_size
+
+        # This will allocate some additional memory, but as this is only done once during model load,
+        # it should be ok.
+        attn = Attention(
+            self.channels,
+            heads=self.num_heads,
+            dim_head=dim_head,
+            bias=True,
+            upcast_softmax=True,
+            norm_num_groups=self.group_norm.num_groups,
+            processor=SpatialAttnProcessor(),
+            eps=self.group_norm.eps,
+            rescale_output_factor=self.rescale_output_factor,
+        )

-        # reshape hidden_states
-        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+        param = next(self.parameters())

-        # compute next hidden_states
-        hidden_states = self.proj_attn(hidden_states)
+        device = param.device
+        dtype = param.dtype

-        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
+        attn.to(device=device, dtype=dtype)

-        # res connect and rescale
-        hidden_states = (hidden_states + residual) / self.rescale_output_factor
-        return hidden_states
+        attn.group_norm.load_state_dict(self.group_norm.state_dict())
+        attn.to_q.load_state_dict(self.query.state_dict())
+        attn.to_k.load_state_dict(self.key.state_dict())
+        attn.to_v.load_state_dict(self.value.state_dict())
+        attn.to_out[0].load_state_dict(self.proj_attn.state_dict())
+
+        return attn


 class BasicTransformerBlock(nn.Module):
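
The class docstring above points to `ModelMixin#_convert_deprecated_attention_blocks`, which is not part of this diff. As a rough sketch only, assuming that hook walks the module tree and relies on the `_as_attention_processor_attention` method added here, the conversion pass could look like the following; the function name and recursion strategy are illustrative assumptions, not the actual diffusers implementation.

from torch import nn


def convert_deprecated_attention_blocks(model: nn.Module) -> nn.Module:
    # Hypothetical stand-in for `ModelMixin#_convert_deprecated_attention_blocks`:
    # recursively replace every legacy `AttentionBlock` child with the equivalent
    # `Attention` module built by `_as_attention_processor_attention`.
    for name, child in model.named_children():
        if child.__class__.__name__ == "AttentionBlock":
            # The helper copies weights, device and dtype, so this is a drop-in swap.
            setattr(model, name, child._as_attention_processor_attention())
        else:
            convert_deprecated_attention_blocks(child)
    return model

Converting at load time is what allows `forward` and `set_use_memory_efficient_attention_xformers` to simply raise: once loading has finished, no `AttentionBlock` instances should remain in the model.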
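For users who see the new deprecation warning, here is a minimal sketch of the upgrade path the message describes, assuming a pipeline checkpoint that still contains `AttentionBlock` weights; the repository and folder names below are placeholders, not real checkpoints.

from diffusers import DiffusionPipeline

# Loading converts the legacy `AttentionBlock` modules to `Attention` and emits the
# deprecation warning once.
pipe = DiffusionPipeline.from_pretrained("my-old-checkpoint")  # placeholder repo id

# Saving writes out the converted modules; reloading from this folder (or from a
# re-uploaded Hub repository) no longer triggers the warning. For a bare model,
# `ModelMixin#save_pretrained` is used the same way.
pipe.save_pretrained("my-old-checkpoint-converted")  # placeholder output path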