
Commit 2ba42aa

[Community Pipeline] MagicMix (#1839)
* initial
* type hints
* update scheduler type hint
* add to README
* add example generation to README
* v -> mix_factor
* load scheduler from pretrained
1 parent 53c8147 commit 2ba42aa

File tree

2 files changed: +198 −1 lines changed


examples/community/README.md

Lines changed: 46 additions & 1 deletion
@@ -25,6 +25,7 @@ If a community doesn't work as expected, please open an issue and ping the autho
| K-Diffusion Stable Diffusion | Run Stable Diffusion with any of [K-Diffusion's samplers](https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py) | [Stable Diffusion with K Diffusion](#stable-diffusion-with-k-diffusion) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
| Checkpoint Merger Pipeline | Diffusion Pipeline that enables merging of saved model checkpoints | [Checkpoint Merger Pipeline](#checkpoint-merger-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
| Stable Diffusion v1.1-1.4 Comparison | Run all 4 model checkpoints for Stable Diffusion and compare their results together | [Stable Diffusion Comparison](#stable-diffusion-comparisons) | - | [Suvaditya Mukherjee](https://github.com/suvadityamuk) |
| MagicMix | Diffusion Pipeline for semantic mixing of an image and a text prompt | [MagicMix](#magic-mix) | - | [Partho Das](https://github.com/daspartho) |

@@ -815,6 +816,50 @@ plt.title('Stable Diffusion v1.4')
plt.axis('off')

plt.show()
```

As a result, you can view a grid of all 4 generated images together, capturing the difference in training progress across the 4 checkpoints.

### Magic Mix

Implementation of the [MagicMix: Semantic Mixing with Diffusion Models](https://arxiv.org/abs/2210.16056) paper. This is a Diffusion Pipeline for semantic mixing of an image and a text prompt to create a new concept while preserving the spatial layout and geometry of the subject in the image. The pipeline takes an image that provides the layout semantics and a prompt that provides the content semantics for the mixing process.

There are 3 parameters for the method:
- `mix_factor`: the interpolation constant used in the layout generation phase. The greater the value of `mix_factor`, the greater the influence of the prompt on the layout generation process.
- `kmax` and `kmin`: these determine the range of denoising steps used for the layout and content generation phases (see the sketch after this list). A higher value of `kmax` results in more information about the layout of the original image being lost, and a higher value of `kmin` gives more steps to the content generation phase.

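To make the ranges concrete, here is a small sketch of how `kmin` and `kmax` translate into denoising step indices; it mirrors the arithmetic in `magic_mix.py` further down this diff, using the default step count and the `kmin`/`kmax` values from the example below:

```python
# Sketch: how kmin/kmax select denoising steps (mirrors the logic in magic_mix.py below).
steps = 50                          # default number of inference steps
kmin, kmax = 0.3, 0.5               # values used in the example below

tmin = steps - int(kmin * steps)    # 35
tmax = steps - int(kmax * steps)    # 25

# The input image is first noised to the timestep at index tmax, then:
#   - step indices in (tmax, tmin) run the layout generation phase
#     (latents are interpolated with the noised original image using mix_factor)
#   - step indices >= tmin run the content generation phase (plain guided denoising)
```
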
Here is an example usage:

```python
from diffusers import DiffusionPipeline, DDIMScheduler
from PIL import Image

pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="magic_mix",
    scheduler=DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler"),
).to('cuda')

img = Image.open('phone.jpg')
mix_img = pipe(
    img,
    prompt='bed',
    kmin=0.3,
    kmax=0.5,
    mix_factor=0.5,
)
mix_img.save('phone_bed_mix.jpg')
```

`mix_img` is a PIL image that can be saved locally or displayed directly in a Google Colab. The generated image is a mix of the layout semantics of the given image and the content semantics of the prompt.

E.g. given the input image `phone.jpg` below, the above script generates `phone_bed_mix.jpg`:

`phone.jpg`

![phone](https://user-images.githubusercontent.com/59410571/209578593-141467c7-d831-4792-8b9a-b17dc5e47816.jpg)

`phone_bed_mix.jpg`

![phone_bed_mix](https://user-images.githubusercontent.com/59410571/209578602-70f323fa-05b7-4dd6-b055-e40683e37914.jpg)

For more example generations check out this [demo notebook](https://github.com/daspartho/MagicMix/blob/main/demo.ipynb).

examples/community/magic_mix.py

Lines changed: 152 additions & 0 deletions
@@ -0,0 +1,152 @@
from typing import Union

import torch

from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DiffusionPipeline,
    LMSDiscreteScheduler,
    PNDMScheduler,
    UNet2DConditionModel,
)
from PIL import Image
from torchvision import transforms as tfms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer


class MagicMixPipeline(DiffusionPipeline):
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler],
    ):
        super().__init__()

        self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler)

    # convert PIL image to latents
    def encode(self, img):
        with torch.no_grad():
            latent = self.vae.encode(tfms.ToTensor()(img).unsqueeze(0).to(self.device) * 2 - 1)
            latent = 0.18215 * latent.latent_dist.sample()  # Stable Diffusion latent scaling factor
        return latent

    # convert latents to PIL image
    def decode(self, latent):
        latent = (1 / 0.18215) * latent
        with torch.no_grad():
            img = self.vae.decode(latent).sample
        img = (img / 2 + 0.5).clamp(0, 1)
        img = img.detach().cpu().permute(0, 2, 3, 1).numpy()
        img = (img * 255).round().astype("uint8")
        return Image.fromarray(img[0])

    # convert prompt into text embeddings, also unconditional embeddings
    def prep_text(self, prompt):
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )

        text_embedding = self.text_encoder(text_input.input_ids.to(self.device))[0]

        uncond_input = self.tokenizer(
            "",
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )

        uncond_embedding = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

        return torch.cat([uncond_embedding, text_embedding])

    def __call__(
        self,
        img: Image.Image,
        prompt: str,
        kmin: float = 0.3,
        kmax: float = 0.6,
        mix_factor: float = 0.5,
        seed: int = 42,
        steps: int = 50,
        guidance_scale: float = 7.5,
    ) -> Image.Image:
        tmin = steps - int(kmin * steps)
        tmax = steps - int(kmax * steps)

        text_embeddings = self.prep_text(prompt)

        self.scheduler.set_timesteps(steps)

        width, height = img.size
        encoded = self.encode(img)

        torch.manual_seed(seed)
        noise = torch.randn(
            (1, self.unet.in_channels, height // 8, width // 8),
        ).to(self.device)

        # noise the image latents to the timestep at index tmax, where mixing begins
        latents = self.scheduler.add_noise(
            encoded,
            noise,
            timesteps=self.scheduler.timesteps[tmax],
        )

        input = torch.cat([latents] * 2)

        input = self.scheduler.scale_model_input(input, self.scheduler.timesteps[tmax])

        with torch.no_grad():
            pred = self.unet(
                input,
                self.scheduler.timesteps[tmax],
                encoder_hidden_states=text_embeddings,
            ).sample

        # classifier-free guidance
        pred_uncond, pred_text = pred.chunk(2)
        pred = pred_uncond + guidance_scale * (pred_text - pred_uncond)

        latents = self.scheduler.step(pred, self.scheduler.timesteps[tmax], latents).prev_sample

        for i, t in enumerate(tqdm(self.scheduler.timesteps)):
            if i > tmax:
                if i < tmin:  # layout generation phase
                    orig_latents = self.scheduler.add_noise(
                        encoded,
                        noise,
                        timesteps=t,
                    )

                    # interpolate between layout noise and conditionally generated noise to preserve layout semantics
                    input = (mix_factor * latents) + (1 - mix_factor) * orig_latents
                    input = torch.cat([input] * 2)

                else:  # content generation phase
                    input = torch.cat([latents] * 2)

                input = self.scheduler.scale_model_input(input, t)

                with torch.no_grad():
                    pred = self.unet(
                        input,
                        t,
                        encoder_hidden_states=text_embeddings,
                    ).sample

                # classifier-free guidance
                pred_uncond, pred_text = pred.chunk(2)
                pred = pred_uncond + guidance_scale * (pred_text - pred_uncond)

                latents = self.scheduler.step(pred, t, latents).prev_sample

        return self.decode(latents)
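
For reference, the pipeline's `__call__` also exposes `seed`, `steps`, and `guidance_scale`. Below is a minimal usage sketch that sets them explicitly; it assumes the same checkpoint and local input image as the README example above, and the values shown are just the defaults from the signature:

```python
from diffusers import DiffusionPipeline, DDIMScheduler
from PIL import Image

# same setup as the README example above
pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="magic_mix",
    scheduler=DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler"),
).to("cuda")

mix_img = pipe(
    Image.open("phone.jpg"),  # layout image (assumed to exist locally)
    prompt="bed",
    kmin=0.3,
    kmax=0.5,
    mix_factor=0.5,
    seed=42,              # seeds the sampled noise for reproducibility
    steps=50,             # number of scheduler timesteps
    guidance_scale=7.5,   # classifier-free guidance strength
)
mix_img.save("phone_bed_mix.jpg")
```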
