Skip to content

Commit e74c173

Browse files
ayushtuesayushmangalyiyixuxu
authored
[WIP] Add Kandinsky decoder (#3330)
* Add movq Co-authored-by: ayushmangal <[email protected]> Co-authored-by: YiYi Xu <[email protected]>
1 parent 3a5fbf8 commit e74c173

File tree

7 files changed

+539
-35
lines changed

7 files changed

+539
-35
lines changed

Diff for: scripts/convert_kandinsky_to_diffusers.py

+415-5
Large diffs are not rendered by default.

Diff for: src/diffusers/models/attention.py

+21
Original file line numberDiff line numberDiff line change
@@ -369,3 +369,24 @@ def forward(self, x, emb):
369369
x = F.group_norm(x, self.num_groups, eps=self.eps)
370370
x = x * (1 + scale) + shift
371371
return x
372+
373+
class SpatialNorm(nn.Module):
374+
"""
375+
Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002
376+
"""
377+
def __init__(
378+
self,
379+
f_channels,
380+
zq_channels,
381+
):
382+
super().__init__()
383+
self.norm_layer = nn.GroupNorm(num_channels=f_channels,num_groups=32,eps=1e-6,affine=True)
384+
self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
385+
self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
386+
387+
def forward(self, f, zq):
388+
f_size = f.shape[-2:]
389+
zq = F.interpolate(zq, size=f_size, mode="nearest")
390+
norm_f = self.norm_layer(f)
391+
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
392+
return new_f

Diff for: src/diffusers/models/resnet.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import torch.nn as nn
2121
import torch.nn.functional as F
2222

23-
from .attention import AdaGroupNorm
23+
from .attention import AdaGroupNorm, SpatialNorm
2424

2525

2626
class Upsample1D(nn.Module):
@@ -460,7 +460,7 @@ def __init__(
460460
eps=1e-6,
461461
non_linearity="swish",
462462
skip_time_act=False,
463-
time_embedding_norm="default", # default, scale_shift, ada_group
463+
time_embedding_norm="default", # default, scale_shift, ada_group, spatial
464464
kernel=None,
465465
output_scale_factor=1.0,
466466
use_in_shortcut=None,
@@ -487,6 +487,8 @@ def __init__(
487487

488488
if self.time_embedding_norm == "ada_group":
489489
self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
490+
elif self.time_embedding_norm == "spatial":
491+
self.norm1 = SpatialNorm(in_channels, temb_channels)
490492
else:
491493
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
492494

@@ -497,7 +499,7 @@ def __init__(
497499
self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
498500
elif self.time_embedding_norm == "scale_shift":
499501
self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
500-
elif self.time_embedding_norm == "ada_group":
502+
elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
501503
self.time_emb_proj = None
502504
else:
503505
raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
@@ -506,6 +508,8 @@ def __init__(
506508

507509
if self.time_embedding_norm == "ada_group":
508510
self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
511+
elif self.time_embedding_norm == "spatial":
512+
self.norm2 = SpatialNorm(out_channels, temb_channels)
509513
else:
510514
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
511515

@@ -551,7 +555,7 @@ def __init__(
551555
def forward(self, input_tensor, temb):
552556
hidden_states = input_tensor
553557

554-
if self.time_embedding_norm == "ada_group":
558+
if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
555559
hidden_states = self.norm1(hidden_states, temb)
556560
else:
557561
hidden_states = self.norm1(hidden_states)
@@ -579,7 +583,7 @@ def forward(self, input_tensor, temb):
579583
if temb is not None and self.time_embedding_norm == "default":
580584
hidden_states = hidden_states + temb
581585

582-
if self.time_embedding_norm == "ada_group":
586+
if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
583587
hidden_states = self.norm2(hidden_states, temb)
584588
else:
585589
hidden_states = self.norm2(hidden_states)

Diff for: src/diffusers/models/unet_2d_blocks.py

+39-12
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import torch.nn.functional as F
1919
from torch import nn
2020

21-
from .attention import AdaGroupNorm
21+
from .attention import AdaGroupNorm, AttentionBlock, SpatialNorm
2222
from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
2323
from .dual_transformer_2d import DualTransformer2DModel
2424
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
@@ -348,6 +348,7 @@ def get_up_block(
348348
resnet_act_fn=resnet_act_fn,
349349
resnet_groups=resnet_groups,
350350
resnet_time_scale_shift=resnet_time_scale_shift,
351+
temb_channels=temb_channels
351352
)
352353
elif up_block_type == "AttnUpDecoderBlock2D":
353354
return AttnUpDecoderBlock2D(
@@ -360,6 +361,7 @@ def get_up_block(
360361
resnet_groups=resnet_groups,
361362
attn_num_head_channels=attn_num_head_channels,
362363
resnet_time_scale_shift=resnet_time_scale_shift,
364+
temb_channels=temb_channels
363365
)
364366
elif up_block_type == "KUpBlock2D":
365367
return KUpBlock2D(
@@ -406,7 +408,6 @@ def __init__(
406408
super().__init__()
407409
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
408410
self.add_attention = add_attention
409-
410411
# there is always at least one resnet
411412
resnets = [
412413
ResnetBlock2D(
@@ -439,7 +440,6 @@ def __init__(
439440
upcast_softmax=True,
440441
_from_deprecated_attn_block=True,
441442
)
442-
)
443443
else:
444444
attentions.append(None)
445445

@@ -465,7 +465,8 @@ def forward(self, hidden_states, temb=None):
465465
hidden_states = self.resnets[0](hidden_states, temb)
466466
for attn, resnet in zip(self.attentions, self.resnets[1:]):
467467
if attn is not None:
468-
hidden_states = attn(hidden_states)
468+
hidden_states = attn(hidden_states, temb)
469+
469470
hidden_states = resnet(hidden_states, temb)
470471

471472
return hidden_states
@@ -1971,6 +1972,30 @@ def custom_forward(*inputs):
19711972
return hidden_states
19721973

19731974

1975+
class MOVQAttention(nn.Module):
1976+
def __init__(self, query_dim, temb_channels, attn_num_head_channels):
1977+
super().__init__()
1978+
1979+
self.norm = SpatialNorm(query_dim, temb_channels)
1980+
num_heads = query_dim // attn_num_head_channels if attn_num_head_channels is not None else 1
1981+
dim_head = attn_num_head_channels if attn_num_head_channels is not None else query_dim
1982+
self.attention = Attention(
1983+
query_dim=query_dim,
1984+
heads=num_heads,
1985+
dim_head=dim_head,
1986+
bias=True
1987+
)
1988+
1989+
def forward(self, hidden_states, temb):
1990+
residual = hidden_states
1991+
hidden_states = self.norm(hidden_states, temb).view(hidden_states.shape[0], hidden_states.shape[1], -1)
1992+
hidden_states = self.attention(hidden_states.transpose(1, 2), None, None).transpose(1, 2)
1993+
hidden_states = hidden_states.view(residual.shape)
1994+
hidden_states = hidden_states + residual
1995+
return hidden_states
1996+
1997+
1998+
19741999
class UpDecoderBlock2D(nn.Module):
19752000
def __init__(
19762001
self,
@@ -1985,6 +2010,7 @@ def __init__(
19852010
resnet_pre_norm: bool = True,
19862011
output_scale_factor=1.0,
19872012
add_upsample=True,
2013+
temb_channels=None
19882014
):
19892015
super().__init__()
19902016
resnets = []
@@ -1996,7 +2022,7 @@ def __init__(
19962022
ResnetBlock2D(
19972023
in_channels=input_channels,
19982024
out_channels=out_channels,
1999-
temb_channels=None,
2025+
temb_channels=temb_channels,
20002026
eps=resnet_eps,
20012027
groups=resnet_groups,
20022028
dropout=dropout,
@@ -2014,9 +2040,9 @@ def __init__(
20142040
else:
20152041
self.upsamplers = None
20162042

2017-
def forward(self, hidden_states):
2043+
def forward(self, hidden_states, temb=None):
20182044
for resnet in self.resnets:
2019-
hidden_states = resnet(hidden_states, temb=None)
2045+
hidden_states = resnet(hidden_states, temb=temb)
20202046

20212047
if self.upsamplers is not None:
20222048
for upsampler in self.upsamplers:
@@ -2040,6 +2066,7 @@ def __init__(
20402066
attn_num_head_channels=1,
20412067
output_scale_factor=1.0,
20422068
add_upsample=True,
2069+
temb_channels=None
20432070
):
20442071
super().__init__()
20452072
resnets = []
@@ -2052,7 +2079,7 @@ def __init__(
20522079
ResnetBlock2D(
20532080
in_channels=input_channels,
20542081
out_channels=out_channels,
2055-
temb_channels=None,
2082+
temb_channels=temb_channels,
20562083
eps=resnet_eps,
20572084
groups=resnet_groups,
20582085
dropout=dropout,
@@ -2075,7 +2102,6 @@ def __init__(
20752102
upcast_softmax=True,
20762103
_from_deprecated_attn_block=True,
20772104
)
2078-
)
20792105

20802106
self.attentions = nn.ModuleList(attentions)
20812107
self.resnets = nn.ModuleList(resnets)
@@ -2085,10 +2111,10 @@ def __init__(
20852111
else:
20862112
self.upsamplers = None
20872113

2088-
def forward(self, hidden_states):
2114+
def forward(self, hidden_states, temb=None):
20892115
for resnet, attn in zip(self.resnets, self.attentions):
2090-
hidden_states = resnet(hidden_states, temb=None)
2091-
hidden_states = attn(hidden_states)
2116+
hidden_states = resnet(hidden_states, temb=temb)
2117+
hidden_states = attn(hidden_states, temb)
20922118

20932119
if self.upsamplers is not None:
20942120
for upsampler in self.upsamplers:
@@ -2847,3 +2873,4 @@ def forward(
28472873
hidden_states = attn_output + hidden_states
28482874

28492875
return hidden_states
2876+

Diff for: src/diffusers/models/vae.py

+20-9
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from ..utils import BaseOutput, randn_tensor
2222
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
23-
23+
from .attention import SpatialNorm
2424

2525
@dataclass
2626
class DecoderOutput(BaseOutput):
@@ -149,6 +149,7 @@ def __init__(
149149
layers_per_block=2,
150150
norm_num_groups=32,
151151
act_fn="silu",
152+
norm_type="default", # default, spatial
152153
):
153154
super().__init__()
154155
self.layers_per_block = layers_per_block
@@ -164,16 +165,19 @@ def __init__(
164165
self.mid_block = None
165166
self.up_blocks = nn.ModuleList([])
166167

168+
169+
temb_channels = in_channels if norm_type == "spatial" else None
170+
167171
# mid
168172
self.mid_block = UNetMidBlock2D(
169173
in_channels=block_out_channels[-1],
170174
resnet_eps=1e-6,
171175
resnet_act_fn=act_fn,
172176
output_scale_factor=1,
173-
resnet_time_scale_shift="default",
177+
resnet_time_scale_shift=norm_type,
174178
attn_num_head_channels=None,
175179
resnet_groups=norm_num_groups,
176-
temb_channels=None,
180+
temb_channels=temb_channels,
177181
)
178182

179183
# up
@@ -196,19 +200,23 @@ def __init__(
196200
resnet_act_fn=act_fn,
197201
resnet_groups=norm_num_groups,
198202
attn_num_head_channels=None,
199-
temb_channels=None,
203+
temb_channels=temb_channels,
204+
resnet_time_scale_shift=norm_type,
200205
)
201206
self.up_blocks.append(up_block)
202207
prev_output_channel = output_channel
203208

204209
# out
205-
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
210+
if norm_type == "spatial":
211+
self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
212+
else:
213+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
206214
self.conv_act = nn.SiLU()
207215
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
208216

209217
self.gradient_checkpointing = False
210218

211-
def forward(self, z):
219+
def forward(self, z, zq=None):
212220
sample = z
213221
sample = self.conv_in(sample)
214222

@@ -230,15 +238,18 @@ def custom_forward(*inputs):
230238
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample)
231239
else:
232240
# middle
233-
sample = self.mid_block(sample)
241+
sample = self.mid_block(sample, zq)
234242
sample = sample.to(upscale_dtype)
235243

236244
# up
237245
for up_block in self.up_blocks:
238-
sample = up_block(sample)
246+
sample = up_block(sample, zq)
239247

240248
# post-process
241-
sample = self.conv_norm_out(sample)
249+
if zq is None:
250+
sample = self.conv_norm_out(sample)
251+
else:
252+
sample = self.conv_norm_out(sample, zq)
242253
sample = self.conv_act(sample)
243254
sample = self.conv_out(sample)
244255

Diff for: src/diffusers/models/vq_model.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,11 @@ def __init__(
8282
norm_num_groups: int = 32,
8383
vq_embed_dim: Optional[int] = None,
8484
scaling_factor: float = 0.18215,
85+
norm_type: str = "default"
8586
):
8687
super().__init__()
8788

89+
8890
# pass init params to Encoder
8991
self.encoder = Encoder(
9092
in_channels=in_channels,
@@ -112,6 +114,7 @@ def __init__(
112114
layers_per_block=layers_per_block,
113115
act_fn=act_fn,
114116
norm_num_groups=norm_num_groups,
117+
norm_type=norm_type,
115118
)
116119

117120
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
@@ -131,8 +134,8 @@ def decode(
131134
quant, emb_loss, info = self.quantize(h)
132135
else:
133136
quant = h
134-
quant = self.post_quant_conv(quant)
135-
dec = self.decoder(quant)
137+
quant2 = self.post_quant_conv(quant)
138+
dec = self.decoder(quant2, quant if self.config.norm_type == "spatial" else None)
136139

137140
if not return_dict:
138141
return (dec,)

0 commit comments

Comments
 (0)