
Commit 2261510

sayakpaul and yiyixuxu authored
[Core] Add AuraFlow (#8796)
* add lavender flow transformer --------- Co-authored-by: YiYi Xu <[email protected]>
1 parent 87b9db6 commit 2261510

18 files changed: 1459 additions, 27 deletions
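For orientation, the public entry points added by this commit are `AuraFlowPipeline` and `AuraFlowTransformer2DModel`, registered in the diffs below. A minimal usage sketch; the checkpoint id `fal/AuraFlow`, the prompt, and the sampler settings are illustrative assumptions, not part of this diff:

import torch
from diffusers import AuraFlowPipeline

# Illustrative checkpoint id; substitute whichever repository hosts converted AuraFlow weights.
pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")

image = pipe(
    prompt="a lavender flower in a glass vase, studio lighting",  # example prompt
    num_inference_steps=50,
    guidance_scale=3.5,
).images[0]
image.save("auraflow.png")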
Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
import argparse

import torch
from huggingface_hub import hf_hub_download

from diffusers.models.transformers.auraflow_transformer_2d import AuraFlowTransformer2DModel


def load_original_state_dict(args):
    model_pt = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename="aura_diffusion_pytorch_model.bin")
    state_dict = torch.load(model_pt, map_location="cpu")
    return state_dict


def calculate_layers(state_dict_keys, key_prefix):
    dit_layers = set()
    for k in state_dict_keys:
        if key_prefix in k:
            dit_layers.add(int(k.split(".")[2]))
    print(f"{key_prefix}: {len(dit_layers)}")
    return len(dit_layers)


# similar to SD3 but only for the last norm layer
def swap_scale_shift(weight, dim):
    shift, scale = weight.chunk(2, dim=0)
    new_weight = torch.cat([scale, shift], dim=0)
    return new_weight


def convert_transformer(state_dict):
    converted_state_dict = {}
    state_dict_keys = list(state_dict.keys())

    converted_state_dict["register_tokens"] = state_dict.pop("model.register_tokens")
    converted_state_dict["pos_embed.pos_embed"] = state_dict.pop("model.positional_encoding")
    converted_state_dict["pos_embed.proj.weight"] = state_dict.pop("model.init_x_linear.weight")
    converted_state_dict["pos_embed.proj.bias"] = state_dict.pop("model.init_x_linear.bias")

    converted_state_dict["time_step_proj.linear_1.weight"] = state_dict.pop("model.t_embedder.mlp.0.weight")
    converted_state_dict["time_step_proj.linear_1.bias"] = state_dict.pop("model.t_embedder.mlp.0.bias")
    converted_state_dict["time_step_proj.linear_2.weight"] = state_dict.pop("model.t_embedder.mlp.2.weight")
    converted_state_dict["time_step_proj.linear_2.bias"] = state_dict.pop("model.t_embedder.mlp.2.bias")

    converted_state_dict["context_embedder.weight"] = state_dict.pop("model.cond_seq_linear.weight")

    mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers")
    single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers")

    # MMDiT blocks 🎸.
    for i in range(mmdit_layers):
        # feed-forward
        path_mapping = {"mlpX": "ff", "mlpC": "ff_context"}
        weight_mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
        for orig_k, diffuser_k in path_mapping.items():
            for k, v in weight_mapping.items():
                converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{v}.weight"] = state_dict.pop(
                    f"model.double_layers.{i}.{orig_k}.{k}.weight"
                )

        # norms
        path_mapping = {"modX": "norm1", "modC": "norm1_context"}
        for orig_k, diffuser_k in path_mapping.items():
            converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.linear.weight"] = state_dict.pop(
                f"model.double_layers.{i}.{orig_k}.1.weight"
            )

        # attns
        x_attn_mapping = {"w2q": "to_q", "w2k": "to_k", "w2v": "to_v", "w2o": "to_out.0"}
        context_attn_mapping = {"w1q": "add_q_proj", "w1k": "add_k_proj", "w1v": "add_v_proj", "w1o": "to_add_out"}
        for attn_mapping in [x_attn_mapping, context_attn_mapping]:
            for k, v in attn_mapping.items():
                converted_state_dict[f"joint_transformer_blocks.{i}.attn.{v}.weight"] = state_dict.pop(
                    f"model.double_layers.{i}.attn.{k}.weight"
                )

    # Single-DiT blocks.
    for i in range(single_dit_layers):
        # feed-forward
        mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
        for k, v in mapping.items():
            converted_state_dict[f"single_transformer_blocks.{i}.ff.{v}.weight"] = state_dict.pop(
                f"model.single_layers.{i}.mlp.{k}.weight"
            )

        # norms
        converted_state_dict[f"single_transformer_blocks.{i}.norm1.linear.weight"] = state_dict.pop(
            f"model.single_layers.{i}.modCX.1.weight"
        )

        # attns
        x_attn_mapping = {"w1q": "to_q", "w1k": "to_k", "w1v": "to_v", "w1o": "to_out.0"}
        for k, v in x_attn_mapping.items():
            converted_state_dict[f"single_transformer_blocks.{i}.attn.{v}.weight"] = state_dict.pop(
                f"model.single_layers.{i}.attn.{k}.weight"
            )

    # Final blocks.
    converted_state_dict["proj_out.weight"] = state_dict.pop("model.final_linear.weight")
    converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(state_dict.pop("model.modF.1.weight"), dim=None)

    return converted_state_dict


@torch.no_grad()
def populate_state_dict(args):
    original_state_dict = load_original_state_dict(args)
    state_dict_keys = list(original_state_dict.keys())
    mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers")
    single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers")

    converted_state_dict = convert_transformer(original_state_dict)
    model_diffusers = AuraFlowTransformer2DModel(
        num_mmdit_layers=mmdit_layers, num_single_dit_layers=single_dit_layers
    )
    model_diffusers.load_state_dict(converted_state_dict, strict=True)

    return model_diffusers


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--original_state_dict_repo_id", default="AuraDiffusion/auradiffusion-v0.1a0", type=str)
    parser.add_argument("--dump_path", default="aura-flow", type=str)
    parser.add_argument("--hub_id", default=None, type=str)
    args = parser.parse_args()

    model_diffusers = populate_state_dict(args)
    model_diffusers.save_pretrained(args.dump_path)
    if args.hub_id is not None:
        model_diffusers.push_to_hub(args.hub_id)
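As a quick sanity check after running the conversion script above (its filename is not shown in this view), the directory written by `save_pretrained()` can be reloaded with the standard diffusers API. A minimal sketch, assuming the default `--dump_path` of `aura-flow`:

import torch
from diffusers import AuraFlowTransformer2DModel

# Reload the transformer that populate_state_dict() converted and save_pretrained() wrote out.
transformer = AuraFlowTransformer2DModel.from_pretrained("aura-flow", torch_dtype=torch.float32)
print(f"{sum(p.numel() for p in transformer.parameters()) / 1e9:.2f}B parameters")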

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -76,6 +76,7 @@
     _import_structure["models"].extend(
         [
             "AsymmetricAutoencoderKL",
+            "AuraFlowTransformer2DModel",
             "AutoencoderKL",
             "AutoencoderKLTemporalDecoder",
             "AutoencoderTiny",
@@ -235,6 +236,7 @@
             "AudioLDM2ProjectionModel",
             "AudioLDM2UNet2DConditionModel",
             "AudioLDMPipeline",
+            "AuraFlowPipeline",
             "BlipDiffusionControlNetPipeline",
             "BlipDiffusionPipeline",
             "ChatGLMModel",
@@ -507,6 +509,7 @@
 else:
     from .models import (
         AsymmetricAutoencoderKL,
+        AuraFlowTransformer2DModel,
         AutoencoderKL,
         AutoencoderKLTemporalDecoder,
         AutoencoderTiny,
@@ -646,6 +649,7 @@
         AudioLDM2ProjectionModel,
         AudioLDM2UNet2DConditionModel,
         AudioLDMPipeline,
+        AuraFlowPipeline,
         ChatGLMModel,
         ChatGLMTokenizer,
         CLIPImageProjection,

src/diffusers/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -38,6 +38,7 @@
     _import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
     _import_structure["embeddings"] = ["ImageProjection"]
     _import_structure["modeling_utils"] = ["ModelMixin"]
+    _import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"]
     _import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"]
     _import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"]
     _import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"]
@@ -84,6 +85,7 @@
     from .embeddings import ImageProjection
     from .modeling_utils import ModelMixin
     from .transformers import (
+        AuraFlowTransformer2DModel,
         DiTTransformer2DModel,
         DualTransformer2DModel,
         HunyuanDiT2DModel,
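The two `__init__.py` diffs above only wire the new classes into the package's lazy import machinery. A purely illustrative check of the resulting import surface:

# Both names resolve lazily through the _import_structure entries added in this commit.
from diffusers import AuraFlowPipeline, AuraFlowTransformer2DModel

print(AuraFlowPipeline.__name__, AuraFlowTransformer2DModel.__name__)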

src/diffusers/models/attention_processor.py

Lines changed: 114 additions & 4 deletions
@@ -22,7 +22,7 @@
 from ..image_processor import IPAdapterMaskProcessor
 from ..utils import deprecate, logging
 from ..utils.import_utils import is_torch_npu_available, is_xformers_available
-from ..utils.torch_utils import maybe_allow_in_graph
+from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -104,6 +104,7 @@ def __init__(
         cross_attention_norm_num_groups: int = 32,
         qk_norm: Optional[str] = None,
         added_kv_proj_dim: Optional[int] = None,
+        added_proj_bias: Optional[bool] = True,
         norm_num_groups: Optional[int] = None,
         spatial_norm_dim: Optional[int] = None,
         out_bias: bool = True,
@@ -118,6 +119,10 @@ def __init__(
         context_pre_only=None,
     ):
         super().__init__()
+
+        # To prevent circular import.
+        from .normalization import FP32LayerNorm
+
         self.inner_dim = out_dim if out_dim is not None else dim_head * heads
         self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
         self.query_dim = query_dim
@@ -170,6 +175,9 @@ def __init__(
         elif qk_norm == "layer_norm":
             self.norm_q = nn.LayerNorm(dim_head, eps=eps)
             self.norm_k = nn.LayerNorm(dim_head, eps=eps)
+        elif qk_norm == "fp32_layer_norm":
+            self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+            self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
         elif qk_norm == "layer_norm_across_heads":
             # Lumina applys qk norm across all heads
             self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
@@ -211,10 +219,10 @@ def __init__(
             self.to_v = None

         if self.added_kv_proj_dim is not None:
-            self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim)
-            self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim)
+            self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
+            self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
             if self.context_pre_only is not None:
-                self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
+                self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)

         self.to_out = nn.ModuleList([])
         self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
@@ -223,6 +231,14 @@ def __init__(
         if self.context_pre_only is not None and not self.context_pre_only:
             self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)

+        if qk_norm is not None and added_kv_proj_dim is not None:
+            if qk_norm == "fp32_layer_norm":
+                self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+                self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+        else:
+            self.norm_added_q = None
+            self.norm_added_k = None
+
         # set attention processor
         # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
         # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
@@ -1137,6 +1153,100 @@ def __call__(
         return hidden_states, encoder_hidden_states


+class AuraFlowAttnProcessor2_0:
+    """Attention processor used typically in processing Aura Flow."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
+            raise ImportError(
+                "AuraFlowAttnProcessor2_0 requires PyTorch 2.1 or above, as we use `scale` in `F.scaled_dot_product_attention()`. Please upgrade PyTorch."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        i=0,
+        *args,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        batch_size = hidden_states.shape[0]
+
+        # `sample` projections.
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        # `context` projections.
+        if encoder_hidden_states is not None:
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+        # Reshape.
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+        query = query.view(batch_size, -1, attn.heads, head_dim)
+        key = key.view(batch_size, -1, attn.heads, head_dim)
+        value = value.view(batch_size, -1, attn.heads, head_dim)
+
+        # Apply QK norm.
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Concatenate the projections.
+        if encoder_hidden_states is not None:
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            )
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            )
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+
+        # Attention.
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # Split the attention outputs.
+        if encoder_hidden_states is not None:
+            hidden_states, encoder_hidden_states = (
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+            )
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+        if encoder_hidden_states is not None:
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        if encoder_hidden_states is not None:
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
 class XFormersAttnAddedKVProcessor:
     r"""
     Processor for implementing memory efficient attention using xFormers.
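To show how the pieces added here fit together (the `fp32_layer_norm` option for `qk_norm`, the new `added_proj_bias` flag, and `AuraFlowAttnProcessor2_0`), here is a standalone sketch; the sizes (dim=3072, 12 heads) and flag values are illustrative assumptions rather than values taken from this diff:

import torch
from diffusers.models.attention_processor import Attention, AuraFlowAttnProcessor2_0

dim, heads = 3072, 12  # illustrative sizes
attn = Attention(
    query_dim=dim,
    cross_attention_dim=None,
    added_kv_proj_dim=dim,        # enables add_{q,k,v}_proj for the text/context stream
    added_proj_bias=False,        # new flag introduced in this commit
    dim_head=dim // heads,
    heads=heads,
    qk_norm="fp32_layer_norm",    # new option: FP32LayerNorm on q/k and on the added q/k
    out_dim=dim,
    bias=False,
    out_bias=False,
    context_pre_only=False,
    processor=AuraFlowAttnProcessor2_0(),
)

hidden_states = torch.randn(2, 256, dim)         # image tokens
encoder_hidden_states = torch.randn(2, 77, dim)  # text tokens
out, context_out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
print(out.shape, context_out.shape)  # (2, 256, 3072), (2, 77, 3072)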

src/diffusers/models/embeddings.py

Lines changed: 3 additions & 1 deletion
@@ -473,18 +473,20 @@ def forward(self, sample, condition=None):


 class Timesteps(nn.Module):
-    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
+    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
         super().__init__()
         self.num_channels = num_channels
         self.flip_sin_to_cos = flip_sin_to_cos
         self.downscale_freq_shift = downscale_freq_shift
+        self.scale = scale

     def forward(self, timesteps):
         t_emb = get_timestep_embedding(
             timesteps,
             self.num_channels,
             flip_sin_to_cos=self.flip_sin_to_cos,
             downscale_freq_shift=self.downscale_freq_shift,
+            scale=self.scale,
         )
         return t_emb
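The new `scale` argument is simply forwarded to `get_timestep_embedding()`, as the diff shows. A short sketch; the channel count and scale value below are illustrative:

import torch
from diffusers.models.embeddings import Timesteps

# `scale` multiplies the timestep phases inside get_timestep_embedding() before the sin/cos projection.
time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
t_emb = time_proj(torch.tensor([0.25, 0.5]))  # e.g. continuous timesteps from a flow-matching schedule
print(t_emb.shape)  # torch.Size([2, 256])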
