
Commit 78a6eed

Add bit diffusion [WIP] (#971)
* Create bit_diffusion.py: Bit diffusion based on the paper arXiv:2208.04202, Chen2022AnalogBG
* adding bit diffusion to new branch, ran tests
* tests
* tests
* tests
* tests
* removed test folders + added to README
* Update README.md

Co-authored-by: Patrick von Platen <[email protected]>
1 parent 94b27fb commit 78a6eed

2 files changed: +273 −1 lines

examples/community/README.md

Lines changed: 10 additions & 1 deletion
@@ -21,6 +21,7 @@ If a community doesn't work as expected, please open an issue and ping the autho
| Multilingual Stable Diffusion| Stable Diffusion Pipeline that supports prompts in 50 different languages. | [Multilingual Stable Diffusion](#multilingual-stable-diffusion-pipeline) | - | [Juan Carlos Piñeros](https://github.com/juancopi81) |
| Image to Image Inpainting Stable Diffusion | Stable Diffusion Pipeline that enables the overlaying of two images and subsequent inpainting| [Image to Image Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | - | [Alex McKinney](https://github.com/vvvm23) |
| Text Based Inpainting Stable Diffusion | Stable Diffusion Inpainting Pipeline that enables passing a text prompt to generate the mask for inpainting| [Text Based Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | - | [Dhruv Karan](https://github.com/unography) |
| Bit Diffusion | Diffusion on discrete data | [Bit Diffusion](#bit-diffusion) | - | [Stuti R.](https://github.com/kingstut) |

@@ -343,7 +344,6 @@ out = pipe(
)
```

### Composable Stable diffusion

[Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) proposes conjunction and negation (negative prompts) operators for compositional generation with conditional diffusion models.
@@ -655,3 +655,12 @@ prompt = "a cup" # the masked out region will be replaced with this
with autocast("cuda"):
    image = pipe(image=image, text=text, prompt=prompt).images[0]
```

### Bit Diffusion

Based on https://arxiv.org/abs/2208.04202, this pipeline runs diffusion on discrete data, e.g. discrete image data or DNA sequence data. An unconditional discrete image can be generated like this:

```python
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", custom_pipeline="bit_diffusion")
image = pipe().images[0]
```
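As a further illustration, here is a minimal sketch building on the example above (not taken from the committed README): `num_inference_steps` is an argument of `BitDiffusion.__call__` (default 50) and can be lowered, and the returned PIL image can be saved directly.

```python
from diffusers import DiffusionPipeline

# load the community pipeline exactly as in the README example above
pipe = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", custom_pipeline="bit_diffusion")

# num_inference_steps is defined in BitDiffusion.__call__ (defaults to 50)
image = pipe(num_inference_steps=25).images[0]

# output_type defaults to "pil", so the result is a PIL image and can be saved
image.save("bit_diffusion_sample.png")
```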

examples/community/bit_diffusion.py

Lines changed: 263 additions & 0 deletions
@@ -0,0 +1,263 @@
from typing import Optional, Tuple, Union

import torch

from diffusers import DDIMScheduler, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.pipeline_utils import ImagePipelineOutput
from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
from diffusers.schedulers.scheduling_ddpm import DDPMSchedulerOutput
from einops import rearrange, reduce


BITS = 8


# convert to bit representations and back taken from https://github.com/lucidrains/bit-diffusion/blob/main/bit_diffusion/bit_diffusion.py
def decimal_to_bits(x, bits=BITS):
    """expects image tensor ranging from 0 to 1, outputs bit tensor ranging from -1 to 1"""
    device = x.device

    x = (x * 255).int().clamp(0, 255)

    mask = 2 ** torch.arange(bits - 1, -1, -1, device=device)
    mask = rearrange(mask, "d -> d 1 1")
    x = rearrange(x, "b c h w -> b c 1 h w")

    bits = ((x & mask) != 0).float()
    bits = rearrange(bits, "b c d h w -> b (c d) h w")
    bits = bits * 2 - 1
    return bits


def bits_to_decimal(x, bits=BITS):
    """expects bits from -1 to 1, outputs image tensor from 0 to 1"""
    device = x.device

    x = (x > 0).int()
    mask = 2 ** torch.arange(bits - 1, -1, -1, device=device, dtype=torch.int32)

    mask = rearrange(mask, "d -> d 1 1")
    x = rearrange(x, "b (c d) h w -> b c d h w", d=8)
    dec = reduce(x * mask, "b c d h w -> b c h w", "sum")
    return (dec / 255).clamp(0.0, 1.0)

# modified scheduler step functions for clamping the predicted x_0 between -bit_scale and +bit_scale
def ddim_bit_scheduler_step(
    self,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = True,
    generator=None,
    return_dict: bool = True,
) -> Union[DDIMSchedulerOutput, Tuple]:
    """
    Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
    process from the learned model outputs (most often the predicted noise).
    Args:
        model_output (`torch.FloatTensor`): direct output from learned diffusion model.
        timestep (`int`): current discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor`):
            current instance of sample being created by diffusion process.
        eta (`float`): weight of noise for added noise in diffusion step.
        use_clipped_model_output (`bool`): TODO
        generator: random number generator.
        return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class
    Returns:
        [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
        [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
        returning a tuple, the first element is the sample tensor.
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
    # Ideally, read DDIM paper in-detail understanding

    # Notation (<variable name> -> <name in paper>
    # - pred_noise_t -> e_theta(x_t, t)
    # - pred_original_sample -> f_theta(x_t, t) or x_0
    # - std_dev_t -> sigma_t
    # - eta -> η
    # - pred_sample_direction -> "direction pointing to x_t"
    # - pred_prev_sample -> "x_t-1"

    # 1. get previous step value (=t-1)
    prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

    # 2. compute alphas, betas
    alpha_prod_t = self.alphas_cumprod[timestep]
    alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

    beta_prod_t = 1 - alpha_prod_t

    # 3. compute predicted original sample from predicted noise also called
    # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)

    # 4. Clip "predicted x_0"
    scale = self.bit_scale
    if self.config.clip_sample:
        pred_original_sample = torch.clamp(pred_original_sample, -scale, scale)

    # 5. compute variance: "sigma_t(η)" -> see formula (16)
    # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
    variance = self._get_variance(timestep, prev_timestep)
    std_dev_t = eta * variance ** (0.5)

    if use_clipped_model_output:
        # the model_output is always re-derived from the clipped x_0 in Glide
        model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)

    # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output

    # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

    if eta > 0:
        # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
        device = model_output.device if torch.is_tensor(model_output) else "cpu"
        noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
        variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise

        prev_sample = prev_sample + variance

    if not return_dict:
        return (prev_sample,)

    return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

def ddpm_bit_scheduler_step(
    self,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    predict_epsilon=True,
    generator=None,
    return_dict: bool = True,
) -> Union[DDPMSchedulerOutput, Tuple]:
    """
    Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
    process from the learned model outputs (most often the predicted noise).
    Args:
        model_output (`torch.FloatTensor`): direct output from learned diffusion model.
        timestep (`int`): current discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor`):
            current instance of sample being created by diffusion process.
        predict_epsilon (`bool`):
            optional flag to use when model predicts the samples directly instead of the noise, epsilon.
        generator: random number generator.
        return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class
    Returns:
        [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] or `tuple`:
        [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
        returning a tuple, the first element is the sample tensor.
    """
    t = timestep

    if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
        model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
    else:
        predicted_variance = None

    # 1. compute alphas, betas
    alpha_prod_t = self.alphas_cumprod[t]
    alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev

    # 2. compute predicted original sample from predicted noise also called
    # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
    if predict_epsilon:
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
    else:
        pred_original_sample = model_output

    # 3. Clip "predicted x_0"
    scale = self.bit_scale
    if self.config.clip_sample:
        pred_original_sample = torch.clamp(pred_original_sample, -scale, scale)

    # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
    # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
    pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t
    current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t

    # 5. Compute predicted previous sample µ_t
    # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
    pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample

    # 6. Add noise
    variance = 0
    if t > 0:
        noise = torch.randn(
            model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator
        ).to(model_output.device)
        variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise

    pred_prev_sample = pred_prev_sample + variance

    if not return_dict:
        return (pred_prev_sample,)

    return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)

class BitDiffusion(DiffusionPipeline):
    def __init__(
        self,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, DDPMScheduler],
        bit_scale: Optional[float] = 1.0,
    ):
        super().__init__()
        self.bit_scale = bit_scale
        # swap in the bit-clamping step function defined above, depending on the scheduler type
        self.scheduler.step = (
            ddim_bit_scheduler_step if isinstance(scheduler, DDIMScheduler) else ddpm_bit_scheduler_step
        )

        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        height: Optional[int] = 256,
        width: Optional[int] = 256,
        num_inference_steps: Optional[int] = 50,
        generator: Optional[torch.Generator] = None,
        batch_size: Optional[int] = 1,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        **kwargs,
    ) -> Union[Tuple, ImagePipelineOutput]:
        # draw random latents and convert them to the analog-bit representation, scaled by bit_scale
        latents = torch.randn(
            (batch_size, self.unet.in_channels, height, width),
            generator=generator,
        )
        latents = decimal_to_bits(latents) * self.bit_scale
        latents = latents.to(self.device)

        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.progress_bar(self.scheduler.timesteps):
            # predict the noise residual
            noise_pred = self.unet(latents, t).sample

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        # convert the final analog bits back to a decimal image in [0, 1]
        image = bits_to_decimal(latents)

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
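As a quick sanity check of the bit-conversion helpers above, the sketch below round-trips a random image tensor through `decimal_to_bits` and `bits_to_decimal`. It assumes `examples/community/bit_diffusion.py` is importable as a module named `bit_diffusion`; that import path is an assumption for illustration, not something set up by this commit.

```python
# Hypothetical round-trip check for the bit conversion helpers above.
# Assumes examples/community/bit_diffusion.py is importable as `bit_diffusion`.
import torch

from bit_diffusion import bits_to_decimal, decimal_to_bits

image = torch.rand(2, 3, 32, 32)   # values in [0, 1], as decimal_to_bits expects

bits = decimal_to_bits(image)      # shape (2, 3 * 8, 32, 32), values in {-1, +1}
recovered = bits_to_decimal(bits)  # back to [0, 1], quantized to steps of 1/255

assert bits.shape == (2, 24, 32, 32)
# the round trip is exact up to the 8-bit quantization applied by decimal_to_bits
assert torch.equal(recovered, (image * 255).int().float() / 255)
```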
