diff --git a/examples/community/README.md b/examples/community/README.md index ddb0b8ce9389..a848f74f2a29 100644 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -25,6 +25,7 @@ If a community doesn't work as expected, please open an issue and ping the autho | K-Diffusion Stable Diffusion | Run Stable Diffusion with any of [K-Diffusion's samplers](https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py) | [Stable Diffusion with K Diffusion](#stable-diffusion-with-k-diffusion) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) | | Checkpoint Merger Pipeline | Diffusion Pipeline that enables merging of saved model checkpoints | [Checkpoint Merger Pipeline](#checkpoint-merger-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) | Stable Diffusion v1.1-1.4 Comparison | Run all 4 model checkpoints for Stable Diffusion and compare their results together | [Stable Diffusion Comparison](#stable-diffusion-comparisons) | - | [Suvaditya Mukherjee](https://github.com/suvadityamuk) | +MagicMix | Diffusion Pipeline for semantic mixing of an image and a text prompt | [MagicMix](#magic-mix) | - | [Partho Das](https://github.com/daspartho) | @@ -815,6 +816,50 @@ plt.title('Stable Diffusion v1.4') plt.axis('off') plt.show() +``` + +As a result, you can look at a grid of all 4 generated images being shown together, that captures a difference the advancement of the training between the 4 checkpoints. + +### Magic Mix + +Implementation of the [MagicMix: Semantic Mixing with Diffusion Models](https://arxiv.org/abs/2210.16056) paper. This is a Diffusion Pipeline for semantic mixing of an image and a text prompt to create a new concept while preserving the spatial layout and geometry of the subject in the image. The pipeline takes an image that provides the layout semantics and a prompt that provides the content semantics for the mixing process. + +There are 3 parameters for the method- +- `mix_factor`: It is the interpolation constant used in the layout generation phase. The greater the value of `mix_factor`, the greater the influence of the prompt on the layout generation process. +- `kmax` and `kmin`: These determine the range for the layout and content generation process. A higher value of kmax results in loss of more information about the layout of the original image and a higher value of kmin results in more steps for content generation process. + +Here is an example usage- + ```python +from diffusers import DiffusionPipeline, DDIMScheduler +from PIL import Image + +pipe = DiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + custom_pipeline="magic_mix", + scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler"), +).to('cuda') + +img = Image.open('phone.jpg') +mix_img = pipe( + img, + prompt = 'bed', + kmin = 0.3, + kmax = 0.5, + mix_factor = 0.5, + ) +mix_img.save('phone_bed_mix.jpg') +``` +The `mix_img` is a PIL image that can be saved locally or displayed directly in a google colab. Generated image is a mix of the layout semantics of the given image and the content semantics of the prompt. + +E.g. the above script generates the following image: + +`phone.jpg` + +![206903102-34e79b9f-9ed2-4fac-bb38-82871343c655](https://user-images.githubusercontent.com/59410571/209578593-141467c7-d831-4792-8b9a-b17dc5e47816.jpg) + +`phone_bed_mix.jpg` + +![206903104-913a671d-ef53-4ae4-919d-64c3059c8f67](https://user-images.githubusercontent.com/59410571/209578602-70f323fa-05b7-4dd6-b055-e40683e37914.jpg) -As a result, you can look at a grid of all 4 generated images being shown together, that captures a difference the advancement of the training between the 4 checkpoints. \ No newline at end of file +For more example generations check out this [demo notebook](https://github.com/daspartho/MagicMix/blob/main/demo.ipynb). diff --git a/examples/community/magic_mix.py b/examples/community/magic_mix.py new file mode 100644 index 000000000000..d67aec781c36 --- /dev/null +++ b/examples/community/magic_mix.py @@ -0,0 +1,152 @@ +from typing import Union + +import torch + +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + DiffusionPipeline, + LMSDiscreteScheduler, + PNDMScheduler, + UNet2DConditionModel, +) +from PIL import Image +from torchvision import transforms as tfms +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer + + +class MagicMixPipeline(DiffusionPipeline): + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler], + ): + super().__init__() + + self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler) + + # convert PIL image to latents + def encode(self, img): + with torch.no_grad(): + latent = self.vae.encode(tfms.ToTensor()(img).unsqueeze(0).to(self.device) * 2 - 1) + latent = 0.18215 * latent.latent_dist.sample() + return latent + + # convert latents to PIL image + def decode(self, latent): + latent = (1 / 0.18215) * latent + with torch.no_grad(): + img = self.vae.decode(latent).sample + img = (img / 2 + 0.5).clamp(0, 1) + img = img.detach().cpu().permute(0, 2, 3, 1).numpy() + img = (img * 255).round().astype("uint8") + return Image.fromarray(img[0]) + + # convert prompt into text embeddings, also unconditional embeddings + def prep_text(self, prompt): + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_embedding = self.text_encoder(text_input.input_ids.to(self.device))[0] + + uncond_input = self.tokenizer( + "", + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + uncond_embedding = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + + return torch.cat([uncond_embedding, text_embedding]) + + def __call__( + self, + img: Image.Image, + prompt: str, + kmin: float = 0.3, + kmax: float = 0.6, + mix_factor: float = 0.5, + seed: int = 42, + steps: int = 50, + guidance_scale: float = 7.5, + ) -> Image.Image: + tmin = steps - int(kmin * steps) + tmax = steps - int(kmax * steps) + + text_embeddings = self.prep_text(prompt) + + self.scheduler.set_timesteps(steps) + + width, height = img.size + encoded = self.encode(img) + + torch.manual_seed(seed) + noise = torch.randn( + (1, self.unet.in_channels, height // 8, width // 8), + ).to(self.device) + + latents = self.scheduler.add_noise( + encoded, + noise, + timesteps=self.scheduler.timesteps[tmax], + ) + + input = torch.cat([latents] * 2) + + input = self.scheduler.scale_model_input(input, self.scheduler.timesteps[tmax]) + + with torch.no_grad(): + pred = self.unet( + input, + self.scheduler.timesteps[tmax], + encoder_hidden_states=text_embeddings, + ).sample + + pred_uncond, pred_text = pred.chunk(2) + pred = pred_uncond + guidance_scale * (pred_text - pred_uncond) + + latents = self.scheduler.step(pred, self.scheduler.timesteps[tmax], latents).prev_sample + + for i, t in enumerate(tqdm(self.scheduler.timesteps)): + if i > tmax: + if i < tmin: # layout generation phase + orig_latents = self.scheduler.add_noise( + encoded, + noise, + timesteps=t, + ) + + input = (mix_factor * latents) + ( + 1 - mix_factor + ) * orig_latents # interpolating between layout noise and conditionally generated noise to preserve layout sematics + input = torch.cat([input] * 2) + + else: # content generation phase + input = torch.cat([latents] * 2) + + input = self.scheduler.scale_model_input(input, t) + + with torch.no_grad(): + pred = self.unet( + input, + t, + encoder_hidden_states=text_embeddings, + ).sample + + pred_uncond, pred_text = pred.chunk(2) + pred = pred_uncond + guidance_scale * (pred_text - pred_uncond) + + latents = self.scheduler.step(pred, t, latents).prev_sample + + return self.decode(latents)