
[WIP] Add Kandinsky decoder #3330


Merged
merged 11 commits into from May 12, 2023

Conversation

@ayushtues (Contributor) commented May 4, 2023

Adds a MOVQ-based decoder for Kandinsky 2.1, part of #3308

to-do:

  • Load pretrained weights from the original repo
  • Ensure forward passes produce the same output
  • Integrate with the pipeline

@HuggingFaceDocBuilderDev commented May 4, 2023

The documentation is not available anymore as the PR was closed or merged.

@ayushtues mentioned this pull request May 4, 2023
@yiyixuxu (Collaborator) commented May 4, 2023

Looks super good, thanks!
Will do a detailed review later today.

@yiyixuxu (Collaborator) left a comment

@ayushtues
Great job!! I left a few comments. I will ask @patrickvonplaten or @williamberman to give a review too.
We can wait to make changes until after getting their feedback :)

@@ -55,12 +55,18 @@ def __init__(
norm_num_groups: int = 32,
rescale_output_factor: float = 1.0,
eps: float = 1e-5,
use_spatial_norm: bool = False,
Contributor

This class is deprecated and should not be used anymore - cc @williamberman, high time we remove it ;-)

Let's please try to solve this in the other attention class

@patrickvonplaten (Contributor) commented May 5, 2023

@yiyixuxu Can you help to instead make this work with the Attention class in attention_processor.py?

@yiyixuxu (Collaborator) commented May 5, 2023

@patrickvonplaten

So instead of using the existing UNetMidBlock2D and AttnUpDecoderBlock2D (which currently use the deprecated class), we should:

  1. write new decoder blocks that use the Attention class instead - I think it will be very easy, we just need to add an option here: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py#L60 (a sketch follows below)
  2. once UNetMidBlock2D and AttnUpDecoderBlock2D are refactored with the Attention class, consolidate them

Is this plan OK?
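For illustration only, a minimal sketch of what the option in step 1 could look like - the helper name and the attention_type flag are hypothetical, not existing API:

    # hypothetical sketch: let decoder blocks choose the new Attention class
    # instead of the deprecated AttentionBlock (flag name is made up)
    from diffusers.models.attention import AttentionBlock  # deprecated
    from diffusers.models.attention_processor import Attention

    def build_decoder_attention(channels: int, attention_type: str = "legacy"):
        if attention_type == "legacy":
            return AttentionBlock(channels)  # current (deprecated) path
        # new path: single-head attention over flattened spatial positions
        return Attention(query_dim=channels, heads=1, dim_head=channels, bias=True)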

@ayushtues (Contributor, Author) commented May 6, 2023

We will need to use AttnAddedKVProcessor: to replicate the attention used in Kandinsky, we need residual connections and a norm on the hidden states before the q-k-v calculations, both of which are only present in AttnAddedKVProcessor. However, it has a concatenation step in the k-v calculation here, which needs to be tackled.

This is the attention implementation in the original repo for reference:

        h_ = x
        h_ = self.norm(h_, zq)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)  # b,hw,c
        k = k.reshape(b, c, h * w)  # b,c,hw
        w_ = torch.bmm(q, k)  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_
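
For context, the self.norm(h_, zq) call above is MOVQ's spatial norm: a group norm whose scale and shift are predicted from the quantized latent zq. A minimal sketch of it, following the original repo (layer names are illustrative):

    import torch.nn as nn
    import torch.nn.functional as F

    class SpatialNorm(nn.Module):
        # group-normalize f, then modulate it with 1x1 convs of the quantized latent zq
        def __init__(self, f_channels: int, zq_channels: int):
            super().__init__()
            self.norm_layer = nn.GroupNorm(num_groups=32, num_channels=f_channels, eps=1e-6, affine=True)
            self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1)
            self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1)

        def forward(self, f, zq):
            # resize zq to f's spatial size, then scale/shift the normalized features
            zq = F.interpolate(zq, size=f.shape[-2:], mode="nearest")
            return self.norm_layer(f) * self.conv_y(zq) + self.conv_b(zq)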

Collaborator

@ayushtues ohh, I don't think we should use AttnAddedKVProcessor, since we don't have the additional projection layers for k and v that the processor name indicates 😁

For attention processors, I think we can either:

  1. create a new attention processor, e.g. MOVQAttnProcessor
  2. or make it work with AttnProcessor (https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py#L381), where we would need something like:
      if self.group_norm:
          hidden_states = self.group_norm(hidden_states)

We also need to slightly refactor the Attention class here (https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py#L92) so the group_norm layer can be a spatial_norm (rough sketch below).

I would wait for @patrickvonplaten to clarify how we should approach this. It's a little bit tricky IMO, because it overlaps with a larger effort to completely replace the deprecated AttentionBlock class.
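
For the Attention refactor mentioned above, something along these lines, perhaps (the spatial_norm_dim argument is an assumption for illustration, not existing API):

    # sketch only - inside Attention.__init__, swap the fixed group_norm
    # for a zq-conditioned spatial norm when requested
    if spatial_norm_dim is not None:
        self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
    else:
        self.spatial_norm = None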

Member

To me, this sounds like a better approach:

create a new attention processor, e.g. MOVQAttnProcessor

(with the new Attention class)

I think this helps to eliminate any overlap between this PR and the larger effort for refactoring the attention blocks.
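
A minimal sketch of what such a processor could look like, mirroring the original MOVQ code quoted above; the class name, the attn.spatial_norm attribute, and passing zq in as temb are all assumptions, not final API:

    import torch

    class MOVQAttnProcessor:
        # spatial-norm the input, run attention over flattened spatial
        # positions, then add the residual (cf. the original code above)
        def __call__(self, attn, hidden_states, temb):
            residual = hidden_states
            hidden_states = attn.spatial_norm(hidden_states, temb)  # assumed attribute

            batch, channels, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch, channels, height * width).transpose(1, 2)

            query = attn.head_to_batch_dim(attn.to_q(hidden_states))
            key = attn.head_to_batch_dim(attn.to_k(hidden_states))
            value = attn.head_to_batch_dim(attn.to_v(hidden_states))

            attention_probs = attn.get_attention_scores(query, key)
            hidden_states = attn.batch_to_head_dim(torch.bmm(attention_probs, value))

            hidden_states = attn.to_out[0](hidden_states)  # output projection
            hidden_states = attn.to_out[1](hidden_states)  # dropout

            hidden_states = hidden_states.transpose(1, 2).view(batch, channels, height, width)
            return hidden_states + residual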

@ayushtues (Contributor, Author)

Okay, makes sense - I'll add a new processor for this then, thanks!

Collaborator

@ayushtues great! Let me know once you need another review :) We also need to add tests once everything looks good.

@ayushtues (Contributor, Author) commented May 10, 2023

I instead added a MOVQAttention class in unet_2d_blocks (https://github.com/ayushtues/diffusers/blob/kandinsky_decoder/src/diffusers/models/unet_2d_blocks.py#L1957). It uses the basic AttentionProcessor and handles all the other processing itself, so no changes are needed in either attention.py or attention_processor.py. Something similar seems to have been done in the past here: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/t5_film_transformer.py#L185. Roughly, it looks like the sketch below.
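
A sketch of the wrapper approach under the assumptions above - names are illustrative, and SpatialNorm is the module sketched earlier in this thread:

    import torch.nn as nn
    from diffusers.models.attention_processor import Attention

    class MOVQAttention(nn.Module):
        # owns the spatial norm and a plain Attention layer, and handles the
        # reshape + residual itself, so attention_processor.py stays untouched
        def __init__(self, channels: int, zq_channels: int):
            super().__init__()
            self.norm = SpatialNorm(channels, zq_channels)
            self.attention = Attention(query_dim=channels, heads=1, dim_head=channels, bias=True)

        def forward(self, hidden_states, zq):
            residual = hidden_states
            batch, channels, height, width = hidden_states.shape
            hidden_states = self.norm(hidden_states, zq)
            hidden_states = hidden_states.view(batch, channels, height * width).transpose(1, 2)
            hidden_states = self.attention(hidden_states)
            hidden_states = hidden_states.transpose(1, 2).view(batch, channels, height, width)
            return hidden_states + residual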

@ayushtues (Contributor, Author)

Let me know how this looks, and maybe we can move this block elsewhere if needed

@ayushtues (Contributor, Author)

Thanks @yiyixuxu @patrickvonplaten for the detailed review, will have a look and make the changes soon!

@patrickvonplaten (Contributor)

Great ❤️ Let us know when you need another review @ayushtues

@JincanDeng

Looking forward to this!

@ayushtues (Contributor, Author) commented May 12, 2023

@yiyixuxu I was able to use the decoder implementation in diffusers and generate images - Colab here: https://colab.research.google.com/drive/1jhMcNi9k3xkkuDHZh6jgY0MvsSODWden?usp=sharing

Let me integrate this into the pipeline next.

@patrickvonplaten (Contributor)

Exciting!

@ayushtues (Contributor, Author) commented May 12, 2023

Integrated the decoder into the text2img pipeline - the code below works! Colab: https://colab.research.google.com/drive/1jhMcNi9k3xkkuDHZh6jgY0MvsSODWden#scrollTo=vRwq-F6Q-mjv

from diffusers import KandinskyPipeline, KandinskyPriorPipeline

import torch

device = "cuda"

# inputs
prompt = "red cat, 4k photo"

# create the prior pipeline
pipe_prior = KandinskyPriorPipeline.from_pretrained("YiYiXu/Kandinsky-prior")
pipe_prior.to(device)

# use the prior to generate image embeddings from the prompt
generator = torch.Generator(device=device).manual_seed(0)
image_emb = pipe_prior(prompt, generator=generator)
zero_image_emb = pipe_prior("")

# decode with the text2img pipeline
pipe = KandinskyPipeline.from_pretrained("ayushtues/test-kandinksy")
pipe.to(device)

generator = torch.Generator(device=device).manual_seed(0)
samples = pipe(
    prompt,
    image_embeds=image_emb,
    negative_image_embeds=zero_image_emb,
    height=768,
    width=768,
    num_inference_steps=100,
    generator=generator,
)

samples[0].save("k_image_d_test.png")

[generated image]

@ayushtues (Contributor, Author)

@yiyixuxu maybe you can give it a review now

@yiyixuxu (Collaborator) left a comment

Thanks for addressing our feedback and great job overall!

However, I think we need to do the attention differently - it's going to be pretty straightforward now that this PR just got merged: #3387

I left some comments but I think the easiest way is for me to merge this into my PR and change it from there. We can review it together afterward :)

In the meantime, feel free to start on the img2img!

@@ -82,9 +82,11 @@ def __init__(
norm_num_groups: int = 32,
vq_embed_dim: Optional[int] = None,
scaling_factor: float = 0.18215,
norm_type: str = "default"
Collaborator

Suggested change:
-    norm_type: str = "default"
+    norm_type: str = "default"  # "default", "spatial"

@@ -426,15 +427,23 @@ def __init__(

for _ in range(num_layers):
if self.add_attention:
attentions.append(
Collaborator

I think we shouldn't change this class.

@@ -44,6 +45,21 @@ def get_new_h_w(h, w):
new_w += 1
return new_h * 8, new_w * 8

def process_images(batch):

@@ -94,6 +114,13 @@ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
latents = latents * scheduler.init_noise_sigma
return latents

def get_image(self, latents):
Collaborator

Same here - we don't need to create a separate method.

@@ -371,4 +398,5 @@ def __call__(

_, latents = latents.chunk(2)

return latents
images = self.get_image(latents)

@@ -1945,6 +1955,30 @@ def custom_forward(*inputs):
return hidden_states


class MOVQAttention(nn.Module):
Collaborator

I think we don't need to create a new class anymore, now that this PR is merged: #3387

@yiyixuxu merged commit e74c173 into huggingface:kandinsky May 12, 2023
yiyixuxu pushed a commit that referenced this pull request May 12, 2023