
Commit 37d113c

Authored by kashif, patil-suraj, williamberman, and patrickvonplaten
DiT Pipeline (#1806)
* added dit model
* import
* initial pipeline
* initial convert script
* initial pipeline
* make style
* raise valueerror
* single function
* rename classes
* use DDIMScheduler
* timesteps embedder
* samples to cpu
* fix var names
* fix numpy type
* use timesteps class for proj
* fix typo
* fix arg name
* flip_sin_to_cos and better var names
* fix C shape cal
* make style
* remove unused imports
* cleanup
* add back patch_size
* initial dit doc
* typo
* Update docs/source/api/pipelines/dit.mdx
  Co-authored-by: Suraj Patil <[email protected]>
* added copyright license headers
* added example usage and toc
* fix variable names asserts
* remove comment
* added docs
* fix typo
* upstream changes
* set proper device for drop_ids
* added initial dit pipeline test
* update docs
* fix imports
* make fix-copies
* isort
* fix imports
* get rid of more magic numbers
* fix code when guidance is off
* remove block_kwargs
* cleanup script
* removed to_2tuple
* use FeedForward class instead of another MLP
* style
* work on merging DiTBlock with BasicTransformerBlock
* added missing final_dropout and args to BasicTransformerBlock
* use norm from block
* fix arg
* remove unused arg
* fix call to class_embedder
* use timesteps
* make style
* attn_output gets multiplied
* removed commented code
* use Transformer2D
* use self.is_input_patches
* fix flags
* fixed conversion to use Transformer2DModel
* fixes for pipeline
* remove dit.py
* fix timesteps device
* use randn_tensor and fix fp16 inf.
* timesteps_emb already the right dtype
* fix dit test class
* fix test and style
* fix norm2 usage in vq-diffusion
* added author names to pipeline and ImageNet labels link
* fix tests
* use norm_type as string
* rename dit to transformer
* fix name
* fix test
* set norm_type = "layer" by default
* fix tests
* do not skip common tests
* Update src/diffusers/models/attention.py
  Co-authored-by: Suraj Patil <[email protected]>
* revert AdaLayerNorm API
* fix norm_type name
* make sure all components are in eval mode
* revert norm2 API
* compact
* finish deprecation
* add slow tests
* remove @
* refactor some stuff
* upload
* Update src/diffusers/pipelines/dit/pipeline_dit.py
* finish more
* finish docs
* improve docs
* finish docs

Co-authored-by: Suraj Patil <[email protected]>
Co-authored-by: William Berman <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
1 parent 7e29b74 commit 37d113c

File tree

51 files changed: +995 lines, -235 lines


docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -106,6 +106,8 @@
       title: DDIM
     - local: api/pipelines/ddpm
       title: DDPM
+    - local: api/pipelines/dit
+      title: DiT
     - local: api/pipelines/latent_diffusion
       title: Latent Diffusion
     - local: api/pipelines/paint_by_example

docs/source/en/api/pipelines/dit.mdx

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+<!--Copyright 2022 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+-->
+
+# [Scalable Diffusion Models with Transformers](https://www.wpeebles.com/DiT) (DiT)
+
+## Overview
+
+[Scalable Diffusion Models with Transformers](https://arxiv.org/abs/2212.09748) (DiT) by William Peebles and Saining Xie.
+
+The abstract of the paper is the following:
+
+*We explore a new class of diffusion models based on the transformer architecture. We train latent diffusion models of images, replacing the commonly-used U-Net backbone with a transformer that operates on latent patches. We analyze the scalability of our Diffusion Transformers (DiTs) through the lens of forward pass complexity as measured by Gflops. We find that DiTs with higher Gflops -- through increased transformer depth/width or increased number of input tokens -- consistently have lower FID. In addition to possessing good scalability properties, our largest DiT-XL/2 models outperform all prior diffusion models on the class-conditional ImageNet 512x512 and 256x256 benchmarks, achieving a state-of-the-art FID of 2.27 on the latter.*
+
+The original codebase of this paper can be found here: [facebookresearch/dit](https://github.com/facebookresearch/dit).
+
+## Available Pipelines:
+
+| Pipeline | Tasks | Colab
+|---|---|:---:|
+| [pipeline_dit.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/dit/pipeline_dit.py) | *Conditional Image Generation* | - |
+
+
+## Usage example
+
+```python
+from diffusers import DiTPipeline, DPMSolverMultistepScheduler
+import torch
+
+pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=torch.float16)
+pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+pipe = pipe.to("cuda")
+
+# pick words from Imagenet class labels
+pipe.labels # to print all available words
+
+# pick words that exist in ImageNet
+words = ["white shark", "umbrella"]
+
+class_ids = pipe.get_label_ids(words)
+
+generator = torch.manual_seed(33)
+output = pipe(class_labels=class_ids, num_inference_steps=25, generator=generator)
+
+image = output.images[0] # label 'white shark'
+```
+
+## DiTPipeline
+[[autodoc]] DiTPipeline
+	- all
+	- __call__
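As a brief follow-up to the usage example in this new doc page, here is a minimal sketch (not part of the diff above) for saving every generated image. It assumes `output.images` is a list of PIL images with one entry per class id, which is the usual return type for diffusers pipelines; the filenames are hypothetical:

    # Save one PNG per requested ImageNet class, reusing `words` and `output`
    # from the usage example above. Assumes output.images holds PIL.Image objects.
    for word, image in zip(words, output.images):
        image.save(f"{word.replace(' ', '_')}.png")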

docs/source/en/api/schedulers/overview.mdx

Lines changed: 3 additions & 0 deletions
@@ -37,6 +37,7 @@ To this end, the design of schedulers is such that:
 
 - Schedulers can be used interchangeably between diffusion models in inference to find the preferred trade-off between speed and generation quality.
 - Schedulers are currently by default in PyTorch, but are designed to be framework independent (partial Jax support currently exists).
+- Many diffusion pipelines, such as [`StableDiffusionPipeline`] and [`DiTPipeline`] can use any of [`KarrasDiffusionSchedulers`]
 
 ## Schedulers Summary
 

@@ -80,4 +81,6 @@ The class [`SchedulerOutput`] contains the outputs from any schedulers `step(...
 
 [[autodoc]] schedulers.scheduling_utils.SchedulerOutput
 
+### KarrasDiffusionSchedulers
 
+[[autodoc]] schedulers.scheduling_utils.KarrasDiffusionSchedulers
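To make the new bullet about [`KarrasDiffusionSchedulers`] concrete, here is a small sketch (not taken from this commit) of swapping one of those schedulers into the DiT pipeline. It reuses the facebook/DiT-XL-2-256 checkpoint from the dit.mdx usage example; HeunDiscreteScheduler is just one member of the enum, picked for illustration, and any other member should slot in the same way:

    import torch
    from diffusers import DiTPipeline, HeunDiscreteScheduler

    # Load the DiT pipeline, then replace its default scheduler with another
    # Karras-style scheduler via from_config (same pattern as in dit.mdx).
    pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=torch.float16)
    pipe.scheduler = HeunDiscreteScheduler.from_config(pipe.scheduler.config)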

scripts/convert_dit_to_diffusers.py

Lines changed: 162 additions & 0 deletions
@@ -0,0 +1,162 @@
+import argparse
+import os
+
+import torch
+
+from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, Transformer2DModel
+from torchvision.datasets.utils import download_url
+
+
+pretrained_models = {512: "DiT-XL-2-512x512.pt", 256: "DiT-XL-2-256x256.pt"}
+
+
+def download_model(model_name):
+    """
+    Downloads a pre-trained DiT model from the web.
+    """
+    local_path = f"pretrained_models/{model_name}"
+    if not os.path.isfile(local_path):
+        os.makedirs("pretrained_models", exist_ok=True)
+        web_path = f"https://dl.fbaipublicfiles.com/DiT/models/{model_name}"
+        download_url(web_path, "pretrained_models")
+    model = torch.load(local_path, map_location=lambda storage, loc: storage)
+    return model
+
+
+def main(args):
+    state_dict = download_model(pretrained_models[args.image_size])
+
+    state_dict["pos_embed.proj.weight"] = state_dict["x_embedder.proj.weight"]
+    state_dict["pos_embed.proj.bias"] = state_dict["x_embedder.proj.bias"]
+    state_dict.pop("x_embedder.proj.weight")
+    state_dict.pop("x_embedder.proj.bias")
+
+    for depth in range(28):
+        state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_1.weight"] = state_dict[
+            "t_embedder.mlp.0.weight"
+        ]
+        state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_1.bias"] = state_dict[
+            "t_embedder.mlp.0.bias"
+        ]
+        state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_2.weight"] = state_dict[
+            "t_embedder.mlp.2.weight"
+        ]
+        state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_2.bias"] = state_dict[
+            "t_embedder.mlp.2.bias"
+        ]
+        state_dict[f"transformer_blocks.{depth}.norm1.emb.class_embedder.embedding_table.weight"] = state_dict[
+            "y_embedder.embedding_table.weight"
+        ]
+
+        state_dict[f"transformer_blocks.{depth}.norm1.linear.weight"] = state_dict[
+            f"blocks.{depth}.adaLN_modulation.1.weight"
+        ]
+        state_dict[f"transformer_blocks.{depth}.norm1.linear.bias"] = state_dict[
+            f"blocks.{depth}.adaLN_modulation.1.bias"
+        ]
+
+        q, k, v = torch.chunk(state_dict[f"blocks.{depth}.attn.qkv.weight"], 3, dim=0)
+        q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{depth}.attn.qkv.bias"], 3, dim=0)
+
+        state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
+        state_dict[f"transformer_blocks.{depth}.attn1.to_q.bias"] = q_bias
+        state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
+        state_dict[f"transformer_blocks.{depth}.attn1.to_k.bias"] = k_bias
+        state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
+        state_dict[f"transformer_blocks.{depth}.attn1.to_v.bias"] = v_bias
+
+        state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict[
+            f"blocks.{depth}.attn.proj.weight"
+        ]
+        state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict[f"blocks.{depth}.attn.proj.bias"]
+
+        state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.weight"] = state_dict[f"blocks.{depth}.mlp.fc1.weight"]
+        state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.bias"] = state_dict[f"blocks.{depth}.mlp.fc1.bias"]
+        state_dict[f"transformer_blocks.{depth}.ff.net.2.weight"] = state_dict[f"blocks.{depth}.mlp.fc2.weight"]
+        state_dict[f"transformer_blocks.{depth}.ff.net.2.bias"] = state_dict[f"blocks.{depth}.mlp.fc2.bias"]
+
+        state_dict.pop(f"blocks.{depth}.attn.qkv.weight")
+        state_dict.pop(f"blocks.{depth}.attn.qkv.bias")
+        state_dict.pop(f"blocks.{depth}.attn.proj.weight")
+        state_dict.pop(f"blocks.{depth}.attn.proj.bias")
+        state_dict.pop(f"blocks.{depth}.mlp.fc1.weight")
+        state_dict.pop(f"blocks.{depth}.mlp.fc1.bias")
+        state_dict.pop(f"blocks.{depth}.mlp.fc2.weight")
+        state_dict.pop(f"blocks.{depth}.mlp.fc2.bias")
+        state_dict.pop(f"blocks.{depth}.adaLN_modulation.1.weight")
+        state_dict.pop(f"blocks.{depth}.adaLN_modulation.1.bias")
+
+    state_dict.pop("t_embedder.mlp.0.weight")
+    state_dict.pop("t_embedder.mlp.0.bias")
+    state_dict.pop("t_embedder.mlp.2.weight")
+    state_dict.pop("t_embedder.mlp.2.bias")
+    state_dict.pop("y_embedder.embedding_table.weight")
+
+    state_dict["proj_out_1.weight"] = state_dict["final_layer.adaLN_modulation.1.weight"]
+    state_dict["proj_out_1.bias"] = state_dict["final_layer.adaLN_modulation.1.bias"]
+    state_dict["proj_out_2.weight"] = state_dict["final_layer.linear.weight"]
+    state_dict["proj_out_2.bias"] = state_dict["final_layer.linear.bias"]
+
+    state_dict.pop("final_layer.linear.weight")
+    state_dict.pop("final_layer.linear.bias")
+    state_dict.pop("final_layer.adaLN_modulation.1.weight")
+    state_dict.pop("final_layer.adaLN_modulation.1.bias")
+
+    # DiT XL/2
+    transformer = Transformer2DModel(
+        sample_size=args.image_size // 8,
+        num_layers=28,
+        attention_head_dim=72,
+        in_channels=4,
+        out_channels=8,
+        patch_size=2,
+        attention_bias=True,
+        num_attention_heads=16,
+        activation_fn="gelu-approximate",
+        num_embeds_ada_norm=1000,
+        norm_type="ada_norm_zero",
+        norm_elementwise_affine=False,
+    )
+    transformer.load_state_dict(state_dict, strict=True)
+
+    scheduler = DDIMScheduler(
+        num_train_timesteps=1000,
+        beta_schedule="linear",
+        prediction_type="epsilon",
+        clip_sample=False,
+    )
+
+    vae = AutoencoderKL.from_pretrained(args.vae_model)
+
+    pipeline = DiTPipeline(transformer=transformer, vae=vae, scheduler=scheduler)
+
+    if args.save:
+        pipeline.save_pretrained(args.checkpoint_path)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--image_size",
+        default=256,
+        type=int,
+        required=False,
+        help="Image size of pretrained model, either 256 or 512.",
+    )
+    parser.add_argument(
+        "--vae_model",
+        default="stabilityai/sd-vae-ft-ema",
+        type=str,
+        required=False,
+        help="Path to pretrained VAE model, either stabilityai/sd-vae-ft-mse or stabilityai/sd-vae-ft-ema.",
+    )
+    parser.add_argument(
+        "--save", default=True, type=bool, required=False, help="Whether to save the converted pipeline or not."
+    )
+    parser.add_argument(
+        "--checkpoint_path", default=None, type=str, required=True, help="Path to the output pipeline."
+    )
+
+    args = parser.parse_args()
+    main(args)
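The script above is meant to be run from the command line, for example `python scripts/convert_dit_to_diffusers.py --image_size 256 --checkpoint_path ./dit-xl-2-256` (the output directory name here is hypothetical). As a small follow-up sketch, the pipeline written by save_pretrained can then be reloaded like any other diffusers pipeline:

    from diffusers import DiTPipeline

    # "./dit-xl-2-256" stands in for whatever directory was passed as --checkpoint_path.
    pipe = DiTPipeline.from_pretrained("./dit-xl-2-256")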

src/diffusers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -57,6 +57,7 @@
         DDIMPipeline,
         DDPMPipeline,
         DiffusionPipeline,
+        DiTPipeline,
         ImagePipelineOutput,
         KarrasVePipeline,
         LDMPipeline,
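This one-line export is what makes the top-level imports used elsewhere in the commit work, both in dit.mdx and in the conversion script. A trivial check, assuming a torch-enabled install of diffusers:

    # Both names resolve from the package root after this commit;
    # Transformer2DModel is the backbone class the conversion script instantiates.
    from diffusers import DiTPipeline, Transformer2DModel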

0 commit comments
