[WIP] Add Kandinsky decoder #3330
Conversation
The documentation is not available anymore as the PR was closed or merged.
looks super good! thanks!
@ayushtues
great job!! I left a few comments. I will ask @patrickvonplaten or @williamberman to give a review too.
we can wait to make changes after getting their feedback :)
src/diffusers/models/attention.py
@@ -55,12 +55,18 @@ def __init__(
     norm_num_groups: int = 32,
     rescale_output_factor: float = 1.0,
     eps: float = 1e-5,
+    use_spatial_norm: bool = False,
This class is deprecated and should not be used anymore - cc @williamberman, high time we remove it ;-)
Let's please try to solve this in the other attention class.
@yiyixuxu Can you help to instead make this work with the `Attention` class in `attention_processor.py`?
So instead of using the existing `UNetMidBlock2D`, `AttnUpDecoderBlock2D` (which currently use the deprecated class), we should:

- write new decoder blocks that use the `Attention` class instead - I think it will be very easy, just add an option here https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py#L60 (see the sketch after this list)
- once `UNetMidBlock2D`, `AttnUpDecoderBlock2D` are refactored with the `Attention` class, we consolidate them

is this plan ok?
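To make the first bullet concrete, here is a hypothetical sketch of the kind of option being proposed. The `residual_connection` flag is an assumed name, not something the `Attention` class exposes at this point:

```python
# Hypothetical sketch: building a drop-in for the deprecated AttentionBlock
# from the `Attention` class in attention_processor.py. `residual_connection`
# is the assumed new option; `in_channels` is just an example value.
from diffusers.models.attention_processor import Attention

in_channels = 512  # example channel count of the decoder feature map

attn = Attention(
    query_dim=in_channels,
    heads=1,                   # AttentionBlock-style single-head attention
    dim_head=in_channels,
    norm_num_groups=32,        # group norm on the hidden states before q-k-v
    residual_connection=True,  # assumed option: return x + attention(norm(x))
    bias=True,
)
```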
We will need to use `AttnAddedKVProcessor`, since to replicate the attention used in Kandinsky we need residual connections and a norm on the hidden states before the q-k-v calculations (both of which are only present in `AttnAddedKVProcessor`), but it has this concatenation thing going on in the k-v calculation here, which needs to be tackled.
This is the attention implementation in the original repo for reference:
h_ = x
h_ = self.norm(h_, zq)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h * w)
q = q.permute(0, 2, 1) # b,hw,c
k = k.reshape(b, c, h * w) # b,c,hw
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h * w)
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_ = h_.reshape(b, c, h, w)
h_ = self.proj_out(h_)
return x + h_
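For readers on PyTorch 2.x, the same computation can be written more compactly with scaled-dot-product attention - a sketch only, assuming the same module attributes (`norm`, `q`, `k`, `v`, `proj_out`) as the snippet above:

```python
import torch.nn.functional as F

def movq_attention(self, x, zq):
    # same math as the reference: spatially-conditioned norm on the hidden
    # states before q-k-v, single-head attention, then a residual connection
    h_ = self.norm(x, zq)
    b, c, h, w = x.shape
    q = self.q(h_).reshape(b, c, h * w).transpose(1, 2)  # (b, hw, c)
    k = self.k(h_).reshape(b, c, h * w).transpose(1, 2)
    v = self.v(h_).reshape(b, c, h * w).transpose(1, 2)
    h_ = F.scaled_dot_product_attention(q, k, v)  # softmax(q k^T / sqrt(c)) v
    h_ = h_.transpose(1, 2).reshape(b, c, h, w)
    return x + self.proj_out(h_)
```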
@ayushtues ohh I don't think we should use `AttnAddedKVProcessor`, since we don't have the additional projection layers for k and v that the processor name indicates 😁
for attention processors, I think we can either

- create a new attention processor, e.g. `MOVQAttnProcessor`
- or make it work with `AttnProcessor` (https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py#L381), for which we need to:
  - add an argument to allow a skip-connection
  - add the norm layer, maybe before this line https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py#L393, e.g. `if self.group_norm: hidden_states = self.group_norm(hidden_states)`

we also need to slightly refactor the `Attention` class here (https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py#L92) so the `group_norm` layer can be a `spatial_norm`
I would wait for @patrickvonplaten to clarify how we should approach this. It's a little bit tricky IMO because it overlaps with a larger effort to completely replace the deprecated `AttentionBlock` class
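A rough sketch of what such a `MOVQAttnProcessor` could look like - the processor name and the `attn.spatial_norm` attribute are assumptions here, not existing diffusers API at this point:

```python
import torch

class MOVQAttnProcessor:
    # Hypothetical processor sketch; assumes `attn` exposes spatial_norm,
    # to_q/to_k/to_v, to_out, head_to_batch_dim/batch_to_head_dim and
    # get_attention_scores, as in attention_processor.py.
    def __call__(self, attn, hidden_states, temb):
        residual = hidden_states  # the MOVQ skip-connection
        batch, channel, height, width = hidden_states.shape

        # norm on the hidden states before q-k-v, conditioned on zq (temb)
        hidden_states = attn.spatial_norm(hidden_states, temb)
        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)

        query = attn.head_to_batch_dim(attn.to_q(hidden_states))
        key = attn.head_to_batch_dim(attn.to_k(hidden_states))
        value = attn.head_to_batch_dim(attn.to_v(hidden_states))

        attention_probs = attn.get_attention_scores(query, key)
        hidden_states = attn.batch_to_head_dim(torch.bmm(attention_probs, value))

        hidden_states = attn.to_out[0](hidden_states)  # linear projection
        hidden_states = attn.to_out[1](hidden_states)  # dropout

        hidden_states = hidden_states.transpose(1, 2).view(batch, channel, height, width)
        return hidden_states + residual
```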
To me, this sounds like the better approach:

> create a new attention processor, e.g. `MOVQAttnProcessor` (with the new `Attention` class)

I think this helps eliminate any overlap between this PR and the larger effort to refactor the attention blocks.
Okay makes sense, will add a new processor for this then, thanks!
@ayushtues great! let me know once you need another review :) we need to add tests too once everything looks good
I instead added a `MOVQAttention` in `unet_2d_blocks` here https://github.com/ayushtues/diffusers/blob/kandinsky_decoder/src/diffusers/models/unet_2d_blocks.py#L1957, which uses the basic `AttnProcessor` and handles all the other processing needed in it, so we don't have to change either `attention.py` or `attention_processor.py`; something similar seems to have been done here in the past: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/t5_film_transformer.py#L185
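A sketch of the wrapper idea being described (names assumed; `SpatialNorm` is the MOVQ-style conditioned norm, and its import location is an assumption):

```python
import torch.nn as nn
from diffusers.models.attention_processor import Attention, SpatialNorm  # location assumed

class MOVQAttention(nn.Module):
    # Hypothetical wrapper: keep the MOVQ-specific norm and residual here,
    # delegate the attention itself to `Attention` with its basic AttnProcessor.
    def __init__(self, channels: int, zq_channels: int):
        super().__init__()
        self.norm = SpatialNorm(channels, zq_channels)  # conditioned on zq
        self.attention = Attention(query_dim=channels, heads=1, dim_head=channels, bias=True)

    def forward(self, x, zq):
        b, c, h, w = x.shape
        h_ = self.norm(x, zq).view(b, c, h * w).transpose(1, 2)  # (b, hw, c)
        h_ = self.attention(h_)
        return x + h_.transpose(1, 2).view(b, c, h, w)           # residual
```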
Let me know how this looks, and maybe we can move this block elsewhere if needed
Thanks @yiyixuxu @patrickvonplaten for the detailed review, will have a look and make the changes soon!
Great ❤️ Let us know when you need another review @ayushtues
Looking forward to this!
@yiyixuxu was able to use the decoder implementation in diffusers and generate images, Colab here: https://colab.research.google.com/drive/1jhMcNi9k3xkkuDHZh6jgY0MvsSODWden?usp=sharing. Let me integrate this into the pipeline next
Exciting!
Integrated the decoder into the text2img pipeline - the below works! Colab: https://colab.research.google.com/drive/1jhMcNi9k3xkkuDHZh6jgY0MvsSODWden#scrollTo=vRwq-F6Q-mjv

from diffusers import KandinskyPipeline, KandinskyPriorPipeline
import torch

device = "cuda"

# inputs
prompt = "red cat, 4k photo"
batch_size = 1

# create the prior pipeline
pipe_prior = KandinskyPriorPipeline.from_pretrained("YiYiXu/Kandinsky-prior")
pipe_prior.to(device)

# use the prior to generate image_emb based on our prompt
generator = torch.Generator(device=device).manual_seed(0)
image_emb = pipe_prior(prompt, generator=generator)
zero_image_emb = pipe_prior("")

pipe = KandinskyPipeline.from_pretrained("ayushtues/test-kandinksy")
pipe.to(device)

generator = torch.Generator(device=device).manual_seed(0)
samples = pipe(
    prompt,
    image_embeds=image_emb,
    negative_image_embeds=zero_image_emb,
    height=768,
    width=768,
    num_inference_steps=100,
    generator=generator,
)
samples[0].save("k_image_d_test.png")
@yiyixuxu maybe you can give it a review now |
Thanks for addressing our feedback and great job overall!
However, I think we need to do the attention differently - it's going to be pretty straightforward now that this PR just got merged: #3387
I left some comments, but I think the easiest way is for me to merge this into my PR and change it from there. We can review it together afterward :)
In the meantime, feel free to start on the img2img!
@@ -82,9 +82,11 @@ def __init__(
     norm_num_groups: int = 32,
     vq_embed_dim: Optional[int] = None,
     scaling_factor: float = 0.18215,
+    norm_type: str = "default"
-    norm_type: str = "default"
+    norm_type: str = "default"  # "default", "spatial"
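For context, a sketch of how a flag like this might select the decoder's output norm - the helper and parameter names below are assumptions, with `SpatialNorm` being the MOVQ-style norm conditioned on the quantized latents:

```python
import torch.nn as nn
from diffusers.models.attention_processor import SpatialNorm  # location assumed

def make_output_norm(norm_type: str, channels: int, zq_channels: int, num_groups: int = 32):
    # hypothetical helper: "spatial" -> MOVQ-style norm conditioned on zq,
    # "default" -> the usual GroupNorm used elsewhere in the VQModel decoder
    if norm_type == "spatial":
        return SpatialNorm(channels, zq_channels)
    return nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
```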
@@ -426,15 +427,23 @@ def __init__(

     for _ in range(num_layers):
+        if self.add_attention:
             attentions.append(
I think we shouldn't change this class.
@@ -44,6 +45,21 @@ def get_new_h_w(h, w):
         new_w += 1
     return new_h * 8, new_w * 8

+def process_images(batch):
we don't need this function, let's do post-processing similar to what's done here https://github.com/huggingface/diffusers/blob/kandinsky/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py#L452
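For reference, the linked post-processing is roughly the following pattern (a sketch; attribute names such as `self.movq` follow the linked pipeline and should be treated as assumptions):

```python
# sketch of the linked post-processing, as it would appear inside __call__
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
image = image * 0.5 + 0.5                                 # [-1, 1] -> [0, 1]
image = image.clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).float().numpy()   # BCHW -> BHWC
if output_type == "pil":
    image = self.numpy_to_pil(image)
```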
@@ -94,6 +114,13 @@ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
         latents = latents * scheduler.init_noise_sigma
         return latents

+    def get_image(self, latents):
same, don't need to create a separate method
@@ -371,4 +398,5 @@ def __call__(

         _, latents = latents.chunk(2)

-        return latents
+        images = self.get_image(latents)
let's do post-processing similar to what's done here https://github.com/huggingface/diffusers/blob/kandinsky/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py#L452
@@ -1945,6 +1955,30 @@ def custom_forward(*inputs):
         return hidden_states


+class MOVQAttention(nn.Module):
I think we don't need to create a new class anymore now that this PR is merged: #3387
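With #3387 merged, the `Attention` class can be configured for MOVQ-style attention directly - a sketch, with parameter names taken from `attention_processor.py` around that time (treat them as assumptions):

```python
from diffusers.models.attention_processor import Attention

in_channels = 512  # example channel count for the decoder mid block
zq_channels = 4    # example channel count of the quantized latents

attn = Attention(
    query_dim=in_channels,
    heads=1,
    dim_head=in_channels,
    norm_num_groups=None,          # skip group norm; use spatial norm instead
    spatial_norm_dim=zq_channels,  # SpatialNorm(hidden_states, zq) before q-k-v
    residual_connection=True,      # x + proj_out(attention(...)) as in MOVQ
    bias=True,
    upcast_softmax=True,
)
```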
This reverts commit e74c173.
Adds a MOVQ-based decoder for Kandinsky 2.1, part of #3308
to-do: