From 9581d7a4a2023e92c3b26a7e08937529613e613f Mon Sep 17 00:00:00 2001 From: manuelbrack Date: Thu, 2 Feb 2023 18:32:48 +0100 Subject: [PATCH 01/14] Add semantic guidance pipeline --- src/diffusers/__init__.py | 1 + src/diffusers/pipelines/__init__.py | 1 + .../semantic_stable_diffusion/__init__.py | 32 + .../pipeline_semantic_stable_diffusion.py | 607 ++++++++++++++++++ .../dummy_torch_and_transformers_objects.py | 15 + 5 files changed, 656 insertions(+) create mode 100644 src/diffusers/pipelines/semantic_stable_diffusion/__init__.py create mode 100644 src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index f9803380121a..fefe0201cdb4 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -117,6 +117,7 @@ StableDiffusionInstructPix2PixPipeline, StableDiffusionPipeline, StableDiffusionPipelineSafe, + SemanticStableDiffusionPipeline, StableDiffusionUpscalePipeline, UnCLIPImageVariationPipeline, UnCLIPPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index b69363f545e5..f8b8854b84d4 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -56,6 +56,7 @@ StableDiffusionUpscalePipeline, ) from .stable_diffusion_safe import StableDiffusionPipelineSafe + from .semantic_stable_diffusion import SemanticStableDiffusionPipeline from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline from .versatile_diffusion import ( VersatileDiffusionDualGuidedPipeline, diff --git a/src/diffusers/pipelines/semantic_stable_diffusion/__init__.py b/src/diffusers/pipelines/semantic_stable_diffusion/__init__.py new file mode 100644 index 000000000000..a7716f58e609 --- /dev/null +++ b/src/diffusers/pipelines/semantic_stable_diffusion/__init__.py @@ -0,0 +1,32 @@ +from dataclasses import dataclass +from enum import Enum +from typing import List, Optional, Union + +import numpy as np + +import PIL +from PIL import Image + +from ...utils import BaseOutput, is_torch_available, is_transformers_available + + +@dataclass +class SemanticStableDiffusionPipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + nsfw_content_detected (`List[bool]`) + List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, or `None` if safety checking could not be performed. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + nsfw_content_detected: Optional[List[bool]] + + +if is_transformers_available() and is_torch_available(): + from .pipeline_semantic_stable_diffusion import SemanticStableDiffusionPipeline diff --git a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py new file mode 100644 index 000000000000..5de1b687c9ca --- /dev/null +++ b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py @@ -0,0 +1,607 @@ +import inspect +import warnings +from itertools import repeat +from typing import Callable, List, Optional, Union + +import torch + +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +from ...configuration_utils import FrozenDict +from ...models import AutoencoderKL, UNet2DConditionModel +from ...pipeline_utils import DiffusionPipeline +from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import deprecate, logging +from . import SemanticStableDiffusionPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class SemanticStableDiffusionPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation with latent editing. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + This model builds on the implementation of ['StableDiffusionPipeline'] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`Q16SafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None: + warnings.warn( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + editing_prompt: Optional[Union[str, List[str]]] = None, + editing_prompt_prompt_embeddings=None, + reverse_editing_direction: Optional[Union[bool, List[bool]]] = False, + edit_guidance_scale: Optional[Union[float, List[float]]] = 500, + edit_warmup_steps: Optional[Union[int, List[int]]] = 10, + edit_cooldown_steps: Optional[Union[int, List[int]]] = None, + edit_threshold: Optional[Union[float, List[float]]] = None, + edit_momentum_scale: Optional[float] = 0.1, + edit_mom_beta: Optional[float] = 0.4, + edit_weights: Optional[List[float]] = None, + sem_guidance: Optional[List[torch.FloatTensor]] = None, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + editing_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to use for Semantic guidance. Semantic guidance is disabled by setting + `editing_prompt = None`. Guidance direction of prompt should be specified via + `reverse_editing_direction`. + reverse_editing_direction (`bool` or `List[bool]`, *optional*): + Whether the corresponding prompt in `editing_prompt` should be increased or decreased. + edit_guidance_scale (`float` or `List[float]`, *optional*, defaults to 5): + Guidance scale for semantic guidance. If provided as list values should correspond to `editing_prompt`. + edit_warmup_steps (`float` or `List[float]`, *optional*, defaults to 10): + Number of diffusion steps (for each prompt) for which semantic guidance will not be applied. Momentum + will still be calculated for those steps and applied once all warmup periods are over. + edit_cooldown_steps (`float` or `List[float]`, *optional*, defaults to `None`): + Number of diffusion steps (for each prompt) after which semantic guidance will no longer be applied. + edit_threshold (`float` or `List[float]`, *optional*, defaults to `None`): + Threshold of semantic guidance. + edit_momentum_scale (`float`, *optional*, defaults to 0.1): + Scale of the momentum to be added to the semantic guidance at each diffusion step. If set to 0.0 + momentum will be disabled. Momentum is already built up during warmup, i.e. for diffusion steps smaller + than `sld_warmup_steps`. Momentum will only be added to latent guidance once all warmup periods are + finished. + edit_mom_beta (`float`, *optional*, defaults to 0.4): + Defines how semantic guidance momentum builds up. `edit_mom_beta` indicates how much of the previous + momentum will be kept. Momentum is already built up during warmup, i.e. for diffusion steps smaller + than `edit_warmup_steps`. + edit_weights (`List[float]`, *optional*, defaults to `None`): + Indicates how much each individual concept should influence the overall guidance. If no weights are + provided all concepts are applied equally. + sem_guidance (`List[torch.FloatTensor]`, *optional*): + List of pre-generated guidance vectors to be applied at generation. Length of the list has to + correspond to `num_inference_steps`. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if editing_prompt: + enable_edit_guidance = True + if isinstance(editing_prompt, str): + editing_prompt = [editing_prompt] + enabled_editing_prompts = len(editing_prompt) + elif editing_prompt_prompt_embeddings is not None: + enable_edit_guidance = True + enabled_editing_prompts = editing_prompt_prompt_embeddings.shape[0] + else: + enabled_editing_prompts = 0 + enable_edit_guidance = False + + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + if text_input_ids.shape[-1] > self.tokenizer.model_max_length: + removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if enable_edit_guidance: + # get safety text embeddings + if editing_prompt_prompt_embeddings is None: + edit_concepts_input = self.tokenizer( + [x for item in editing_prompt for x in repeat(item, batch_size)], + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + + edit_concepts_input_ids = edit_concepts_input.input_ids + + if edit_concepts_input_ids.shape[-1] > self.tokenizer.model_max_length: + removed_text = self.tokenizer.batch_decode( + edit_concepts_input_ids[:, self.tokenizer.model_max_length :] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + edit_concepts_input_ids = edit_concepts_input_ids[:, : self.tokenizer.model_max_length] + edit_concepts = self.text_encoder(edit_concepts_input_ids.to(self.device))[0] + else: + edit_concepts = editing_prompt_prompt_embeddings.to(self.device).repeat(batch_size, 1, 1) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed_edit, seq_len_edit, _ = edit_concepts.shape + edit_concepts = edit_concepts.repeat(1, num_images_per_prompt, 1) + edit_concepts = edit_concepts.view(bs_embed_edit * num_images_per_prompt, seq_len_edit, -1) + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + # get unconditional embeddings for classifier free guidance + + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if enable_edit_guidance: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings, edit_concepts]) + else: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + # get the initial random noise unless the user supplied it + + # Unlike in other pipelines, latents need to be generated in the target device + # for 1-to-1 results reproducibility with the CompVis implementation. + # However this currently doesn't work in `mps`. + latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8) + latents_dtype = text_embeddings.dtype + if latents is None: + if self.device.type == "mps": + # randn does not exist on mps + latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to( + self.device + ) + else: + latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + latents = latents.to(self.device) + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + + # Some schedulers like PNDM have timesteps as arrays + # It's more optimized to move all timesteps to correct device beforehand + timesteps_tensor = self.scheduler.timesteps.to(self.device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # Initialize edit_momentum to None + edit_momentum = None + + self.uncond_estimates = None + self.text_estimates = None + self.edit_estimates = None + self.sem_guidance = None + + for i, t in enumerate(self.progress_bar(timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([latents] * (2 + enabled_editing_prompts)) if do_classifier_free_guidance else latents + ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_out = noise_pred.chunk(2 + enabled_editing_prompts) # [b,4, 64, 64] + noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1] + noise_pred_edit_concepts = noise_pred_out[2:] + + # default text guidance + noise_guidance = guidance_scale * (noise_pred_text - noise_pred_uncond) + # noise_guidance = (noise_pred_text - noise_pred_edit_concepts[0]) + + if self.uncond_estimates is None: + self.uncond_estimates = torch.zeros((num_inference_steps + 1, *noise_pred_uncond.shape)) + self.uncond_estimates[i] = noise_pred_uncond.detach().cpu() + + if self.text_estimates is None: + self.text_estimates = torch.zeros((num_inference_steps + 1, *noise_pred_text.shape)) + self.text_estimates[i] = noise_pred_text.detach().cpu() + + if self.edit_estimates is None and enable_edit_guidance: + self.edit_estimates = torch.zeros( + (num_inference_steps + 1, len(noise_pred_edit_concepts), *noise_pred_edit_concepts[0].shape) + ) + + if self.sem_guidance is None: + self.sem_guidance = torch.zeros((num_inference_steps + 1, *noise_pred_text.shape)) + + if edit_momentum is None: + edit_momentum = torch.zeros_like(noise_guidance) + + if enable_edit_guidance: + concept_weights = torch.zeros( + (len(noise_pred_edit_concepts), noise_guidance.shape[0]), device=self.device + ) + noise_guidance_edit = torch.zeros( + (len(noise_pred_edit_concepts), *noise_guidance.shape), device=self.device + ) + # noise_guidance_edit = torch.zeros_like(noise_guidance) + warmup_inds = [] + for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts): + self.edit_estimates[i, c] = noise_pred_edit_concept + if isinstance(edit_guidance_scale, list): + edit_guidance_scale_c = edit_guidance_scale[c] + else: + edit_guidance_scale_c = edit_guidance_scale + + if isinstance(edit_threshold, list): + edit_threshold_c = edit_threshold[c] + else: + edit_threshold_c = edit_threshold + if isinstance(reverse_editing_direction, list): + reverse_editing_direction_c = reverse_editing_direction[c] + else: + reverse_editing_direction_c = reverse_editing_direction + if edit_weights: + edit_weight_c = edit_weights[c] + else: + edit_weight_c = 1.0 + if isinstance(edit_warmup_steps, list): + edit_warmup_steps_c = edit_warmup_steps[c] + else: + edit_warmup_steps_c = edit_warmup_steps + + if isinstance(edit_cooldown_steps, list): + edit_cooldown_steps_c = edit_cooldown_steps[c] + elif edit_cooldown_steps is None: + edit_cooldown_steps_c = i + 1 + else: + edit_cooldown_steps_c = edit_cooldown_steps + if i >= edit_warmup_steps_c: + warmup_inds.append(c) + if i >= edit_cooldown_steps_c: + noise_guidance_edit[c, :, :, :, :] = torch.zeros_like(noise_pred_edit_concept) + continue + + noise_guidance_edit_tmp = noise_pred_edit_concept - noise_pred_uncond + # tmp_weights = (noise_pred_text - noise_pred_edit_concept).sum(dim=(1, 2, 3)) + tmp_weights = (noise_guidance - noise_pred_edit_concept).sum(dim=(1, 2, 3)) + + tmp_weights = torch.full_like(tmp_weights, edit_weight_c) # * (1 / enabled_editing_prompts) + if reverse_editing_direction_c: + noise_guidance_edit_tmp = noise_guidance_edit_tmp * -1 + concept_weights[c, :] = tmp_weights + + noise_guidance_edit_tmp = noise_guidance_edit_tmp * edit_guidance_scale_c + tmp = torch.quantile( + torch.abs(noise_guidance_edit_tmp).flatten(start_dim=2), + edit_threshold_c, + dim=2, + keepdim=False, + ) + noise_guidance_edit_tmp = torch.where( + torch.abs(noise_guidance_edit_tmp) >= tmp[:, :, None, None], + noise_guidance_edit_tmp, + torch.zeros_like(noise_guidance_edit_tmp), + ) + noise_guidance_edit[c, :, :, :, :] = noise_guidance_edit_tmp + + # noise_guidance_edit = noise_guidance_edit + noise_guidance_edit_tmp + + warmup_inds = torch.tensor(warmup_inds).to(self.device) + if len(noise_pred_edit_concepts) > warmup_inds.shape[0] > 0: + concept_weights = concept_weights.to("cpu") # Offload to cpu + noise_guidance_edit = noise_guidance_edit.to("cpu") + + concept_weights_tmp = torch.index_select(concept_weights.to(self.device), 0, warmup_inds) + concept_weights_tmp = torch.where( + concept_weights_tmp < 0, torch.zeros_like(concept_weights_tmp), concept_weights_tmp + ) + concept_weights_tmp = concept_weights_tmp / concept_weights_tmp.sum(dim=0) + # concept_weights_tmp = torch.nan_to_num(concept_weights_tmp) + + noise_guidance_edit_tmp = torch.index_select( + noise_guidance_edit.to(self.device), 0, warmup_inds + ) + noise_guidance_edit_tmp = torch.einsum( + "cb,cbijk->bijk", concept_weights_tmp, noise_guidance_edit_tmp + ) + noise_guidance_edit_tmp = noise_guidance_edit_tmp + noise_guidance = noise_guidance + noise_guidance_edit_tmp + + self.sem_guidance[i] = noise_guidance_edit_tmp.detach().cpu() + + del noise_guidance_edit_tmp + del concept_weights_tmp + concept_weights = concept_weights.to(self.device) + noise_guidance_edit = noise_guidance_edit.to(self.device) + + concept_weights = torch.where( + concept_weights < 0, torch.zeros_like(concept_weights), concept_weights + ) + + concept_weights = torch.nan_to_num(concept_weights) + noise_guidance_edit = torch.einsum("cb,cbijk->bijk", concept_weights, noise_guidance_edit) + + noise_guidance_edit = noise_guidance_edit + edit_momentum_scale * edit_momentum + + edit_momentum = edit_mom_beta * edit_momentum + (1 - edit_mom_beta) * noise_guidance_edit + + if warmup_inds.shape[0] == len(noise_pred_edit_concepts): + noise_guidance = noise_guidance + noise_guidance_edit + self.sem_guidance[i] = noise_guidance_edit.detach().cpu() + + if sem_guidance is not None: + edit_guidance = sem_guidance[i].to(self.device) + noise_guidance = noise_guidance + edit_guidance + + noise_pred = noise_pred_uncond + noise_guidance + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to( + self.device + ) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype) + ) + else: + has_nsfw_concept = None + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return SemanticStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 14aee55e5fca..f1518125b310 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -79,6 +79,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class SemanticStableDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class StableDiffusionDepth2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 9eaa2b0d397c51c0ca2544d3208bfcd223fe91dd Mon Sep 17 00:00:00 2001 From: manuelbrack Date: Thu, 2 Feb 2023 18:46:39 +0100 Subject: [PATCH 02/14] Fix style --- src/diffusers/__init__.py | 2 +- src/diffusers/pipelines/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index fefe0201cdb4..88493b82d8a3 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -109,6 +109,7 @@ CycleDiffusionPipeline, LDMTextToImagePipeline, PaintByExamplePipeline, + SemanticStableDiffusionPipeline, StableDiffusionDepth2ImgPipeline, StableDiffusionImageVariationPipeline, StableDiffusionImg2ImgPipeline, @@ -117,7 +118,6 @@ StableDiffusionInstructPix2PixPipeline, StableDiffusionPipeline, StableDiffusionPipelineSafe, - SemanticStableDiffusionPipeline, StableDiffusionUpscalePipeline, UnCLIPImageVariationPipeline, UnCLIPPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index f8b8854b84d4..60d9ac6bc178 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -44,6 +44,7 @@ from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline from .latent_diffusion import LDMTextToImagePipeline from .paint_by_example import PaintByExamplePipeline + from .semantic_stable_diffusion import SemanticStableDiffusionPipeline from .stable_diffusion import ( CycleDiffusionPipeline, StableDiffusionDepth2ImgPipeline, @@ -56,7 +57,6 @@ StableDiffusionUpscalePipeline, ) from .stable_diffusion_safe import StableDiffusionPipelineSafe - from .semantic_stable_diffusion import SemanticStableDiffusionPipeline from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline from .versatile_diffusion import ( VersatileDiffusionDualGuidedPipeline, From 6a6620dcbe48da3780c81a966147d8eb865c7744 Mon Sep 17 00:00:00 2001 From: manuelbrack Date: Sun, 5 Feb 2023 20:46:01 +0100 Subject: [PATCH 03/14] Refactor Pipeline --- .../pipeline_semantic_stable_diffusion.py | 95 +++++++++++-------- 1 file changed, 57 insertions(+), 38 deletions(-) diff --git a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py index 5de1b687c9ca..77e6a0639d6a 100644 --- a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +++ b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py @@ -2,6 +2,7 @@ import warnings from itertools import repeat from typing import Callable, List, Optional, Union +from packaging import version import torch @@ -11,7 +12,7 @@ from ...models import AutoencoderKL, UNet2DConditionModel from ...pipeline_utils import DiffusionPipeline from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker -from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, KarrasDiffusionSchedulers from ...utils import deprecate, logging from . import SemanticStableDiffusionPipelineOutput @@ -49,15 +50,18 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ + _optional_components = ["safety_checker", "feature_extractor"] + def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, ): super().__init__() @@ -75,8 +79,21 @@ def __init__( new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) - if safety_checker is None: - warnings.warn( + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" @@ -85,6 +102,33 @@ def __init__( " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -94,33 +138,8 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) - - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - self.unet.set_attention_slice(slice_size) - - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.register_to_config(requires_safety_checker=requires_safety_checker) @torch.no_grad() def __call__( From 5a1a1dd1224ffbeed4aa9550ed355067e1c09f98 Mon Sep 17 00:00:00 2001 From: manuelbrack Date: Sun, 5 Feb 2023 21:10:52 +0100 Subject: [PATCH 04/14] Pipeline documentation --- .../pipeline_semantic_stable_diffusion.py | 48 +++++++++++++++---- 1 file changed, 38 insertions(+), 10 deletions(-) diff --git a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py index 77e6a0639d6a..3e2a9e065516 100644 --- a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +++ b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py @@ -18,7 +18,31 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name - +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import SemanticStableDiffusionPipeline + + >>> pipe = SemanticStableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> out = pipe(prompt='a photo of the face of a woman', generator=gen, num_images_per_prompt=1, guidance_scale=7, + >>> editing_prompt=['smiling, smile', # Concepts to apply + 'glasses, wearing glasses', + 'curls, wavy hair, curly hair', + 'beard, full beard, mustache'], + >>> reverse_editing_direction=[False, False, False, False], # Direction of guidance i.e. increase all concepts + >>> edit_warmup_steps=[10, 10, 10,10], # Warmup period for each concept + >>> edit_guidance_scale=[4, 5, 5, 5.4], # Guidance scale for each concept + >>> edit_threshold=[0.99, 0.975, 0.925, 0.96], # Threshold for each concept. Threshold equals the percentile of the latent space that will be discarded. I.e. threshold=0.99 uses 1% of the latent dimensions + >>> edit_momentum_scale=0.3, # Momentum scale that will be added to the latent guidance + >>> edit_beta1=0.6, # Momentum beta + >>> edit_weights=[1,1,1,1,1] # Weights of the individual concepts against each other + >>> ) + >>> image = out.images[0] + ``` +""" class SemanticStableDiffusionPipeline(DiffusionPipeline): r""" @@ -142,6 +166,7 @@ def __init__( self.register_to_config(requires_safety_checker=requires_safety_checker) @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]], @@ -159,12 +184,12 @@ def __call__( callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: Optional[int] = 1, editing_prompt: Optional[Union[str, List[str]]] = None, - editing_prompt_prompt_embeddings=None, + editing_prompt_embeddings: Optional[torch.FloatTensor]=None, reverse_editing_direction: Optional[Union[bool, List[bool]]] = False, - edit_guidance_scale: Optional[Union[float, List[float]]] = 500, + edit_guidance_scale: Optional[Union[float, List[float]]] = 5, edit_warmup_steps: Optional[Union[int, List[int]]] = 10, edit_cooldown_steps: Optional[Union[int, List[int]]] = None, - edit_threshold: Optional[Union[float, List[float]]] = None, + edit_threshold: Optional[Union[float, List[float]]] = 0.9, edit_momentum_scale: Optional[float] = 0.1, edit_mom_beta: Optional[float] = 0.4, edit_weights: Optional[List[float]] = None, @@ -221,7 +246,10 @@ def __call__( The prompt or prompts to use for Semantic guidance. Semantic guidance is disabled by setting `editing_prompt = None`. Guidance direction of prompt should be specified via `reverse_editing_direction`. - reverse_editing_direction (`bool` or `List[bool]`, *optional*): + editing_prompt_embeddings (`torch.FloatTensor`, *optional*): + Pre-computed embeddings to use for semantic guidance. Guidance direction of embedding should be specified via + `reverse_editing_direction`. + reverse_editing_direction (`bool` or `List[bool]`, *optional*, defaults to `False`): Whether the corresponding prompt in `editing_prompt` should be increased or decreased. edit_guidance_scale (`float` or `List[float]`, *optional*, defaults to 5): Guidance scale for semantic guidance. If provided as list values should correspond to `editing_prompt`. @@ -230,7 +258,7 @@ def __call__( will still be calculated for those steps and applied once all warmup periods are over. edit_cooldown_steps (`float` or `List[float]`, *optional*, defaults to `None`): Number of diffusion steps (for each prompt) after which semantic guidance will no longer be applied. - edit_threshold (`float` or `List[float]`, *optional*, defaults to `None`): + edit_threshold (`float` or `List[float]`, *optional*, defaults to 0.9): Threshold of semantic guidance. edit_momentum_scale (`float`, *optional*, defaults to 0.1): Scale of the momentum to be added to the semantic guidance at each diffusion step. If set to 0.0 @@ -279,9 +307,9 @@ def __call__( if isinstance(editing_prompt, str): editing_prompt = [editing_prompt] enabled_editing_prompts = len(editing_prompt) - elif editing_prompt_prompt_embeddings is not None: + elif editing_prompt_embeddings is not None: enable_edit_guidance = True - enabled_editing_prompts = editing_prompt_prompt_embeddings.shape[0] + enabled_editing_prompts = editing_prompt_embeddings.shape[0] else: enabled_editing_prompts = 0 enable_edit_guidance = False @@ -311,7 +339,7 @@ def __call__( if enable_edit_guidance: # get safety text embeddings - if editing_prompt_prompt_embeddings is None: + if editing_prompt_embeddings is None: edit_concepts_input = self.tokenizer( [x for item in editing_prompt for x in repeat(item, batch_size)], padding="max_length", @@ -332,7 +360,7 @@ def __call__( edit_concepts_input_ids = edit_concepts_input_ids[:, : self.tokenizer.model_max_length] edit_concepts = self.text_encoder(edit_concepts_input_ids.to(self.device))[0] else: - edit_concepts = editing_prompt_prompt_embeddings.to(self.device).repeat(batch_size, 1, 1) + edit_concepts = editing_prompt_embeddings.to(self.device).repeat(batch_size, 1, 1) # duplicate text embeddings for each generation per prompt, using mps friendly method bs_embed_edit, seq_len_edit, _ = edit_concepts.shape From e6159b3c1859df5208fa6170f1624dbfdc88d9ad Mon Sep 17 00:00:00 2001 From: manuelbrack Date: Mon, 6 Feb 2023 18:55:42 +0100 Subject: [PATCH 05/14] Add documentation --- .../pipelines/semantic_stable_diffusion.mdx | 70 +++++++++++++++++++ .../api/pipelines/stable_diffusion_safe.mdx | 4 +- docs/source/en/index.mdx | 1 + .../pipeline_semantic_stable_diffusion.py | 6 +- 4 files changed, 76 insertions(+), 5 deletions(-) create mode 100644 docs/source/en/api/pipelines/semantic_stable_diffusion.mdx diff --git a/docs/source/en/api/pipelines/semantic_stable_diffusion.mdx b/docs/source/en/api/pipelines/semantic_stable_diffusion.mdx new file mode 100644 index 000000000000..420a7e385d51 --- /dev/null +++ b/docs/source/en/api/pipelines/semantic_stable_diffusion.mdx @@ -0,0 +1,70 @@ + + +# Semantic Guidance + +Semantic Guidance for Diffusion Models was proposed in [SEGA: Instructing Diffusion using Semantic Dimensions](https://arxiv.org/abs/2301.12247) and provides strong semantic control over the image generation. +Small changes to the text prompt usually result in entirely different output images. However, with SEGA a variety of changes to the image are enabled that can be controlled easily and intuitively, and stay true to the original image composition. + +The abstract of the paper is the following: + +*Text-to-image diffusion models have recently received a lot of interest for their astonishing ability to produce high-fidelity images from text only. However, achieving one-shot generation that aligns with the user's intent is nearly impossible, yet small changes to the input prompt often result in very different images. This leaves the user with little semantic control. To put the user in control, we show how to interact with the diffusion process to flexibly steer it along semantic directions. This semantic guidance (SEGA) allows for subtle and extensive edits, changes in composition and style, as well as optimizing the overall artistic conception. We demonstrate SEGA's effectiveness on a variety of tasks and provide evidence for its versatility and flexibility.* + + +*Overview*: + +| Pipeline | Tasks | Colab | Demo +|---|---|:---:|:---:| +| [pipeline_semantic_stable_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/semantic-image-editing/blob/main/examples/SemanticGuidance.ipynb) | [Coming Soon](https://huggingface.co/AIML-TUDA) + +## Tips + +- The Semantic Guidance pipeline can be used with any [Stable Diffusion](./api/pipelines/stable_diffusion/text2img) checkpoint. + +### Run Semantic Guidance + +The interface of [`SemanticStableDiffusionPipeline`] provides several additional parameters to influence the image generation. +Exemplary usage may look like this: + +```python +import torch +from diffusers import SemanticStableDiffusionPipeline + +pipe = SemanticStableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) +pipe = pipe.to("cuda") + +out = pipe(prompt='a photo of the face of a woman', num_images_per_prompt=1, guidance_scale=7, + editing_prompt=['smiling, smile', # Concepts to apply + 'glasses, wearing glasses', + 'curls, wavy hair, curly hair', + 'beard, full beard, mustache'], + reverse_editing_direction=[False, False, False, False], # Direction of guidance i.e. increase all concepts + edit_warmup_steps=[10, 10, 10,10], # Warmup period for each concept + edit_guidance_scale=[4, 5, 5, 5.4], # Guidance scale for each concept + edit_threshold=[0.99, 0.975, 0.925, 0.96], # Threshold for each concept. Threshold equals the percentile of the latent space that will be discarded. I.e. threshold=0.99 uses 1% of the latent dimensions + edit_momentum_scale=0.3, # Momentum scale that will be added to the latent guidance + edit_mom_beta=0.6, # Momentum beta + edit_weights=[1,1,1,1,1] # Weights of the individual concepts against each other + ) +``` + +For more examples check the colab notebook. + +## StableDiffusionSafePipelineOutput +[[autodoc]] pipelines.semantic_stable_diffusion.SemanticStableDiffusionPipelineOutput + - all + - __call__ + +## SemanticStableDiffusionPipeline +[[autodoc]] SemanticStableDiffusionPipeline + - all + - __call__ diff --git a/docs/source/en/api/pipelines/stable_diffusion_safe.mdx b/docs/source/en/api/pipelines/stable_diffusion_safe.mdx index d2ae319a78e6..adac446ada88 100644 --- a/docs/source/en/api/pipelines/stable_diffusion_safe.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion_safe.mdx @@ -24,7 +24,7 @@ The abstract of the paper is the following: | Pipeline | Tasks | Colab | Demo |---|---|:---:|:---:| -| [pipeline_stable_diffusion_safe.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb) | - +| [pipeline_stable_diffusion_safe.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb) | [![Huggingface Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/AIML-TUDA/unsafe-vs-safe-stable-diffusion) ## Tips @@ -58,7 +58,7 @@ You may use the 4 configurations defined in the [Safe Latent Diffusion paper](ht >>> out = pipeline(prompt=prompt, **SafetyConfig.MAX) ``` -The following configurations are available: `SafetyConfig.WEAK`, `SafetyConfig.MEDIUM`, `SafetyConfig.STRONg`, and `SafetyConfig.MAX`. +The following configurations are available: `SafetyConfig.WEAK`, `SafetyConfig.MEDIUM`, `SafetyConfig.STRONG`, and `SafetyConfig.MAX`. ### How to load and use different schedulers. diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 148ee53f411f..6f1c8b8b3a61 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -47,6 +47,7 @@ available a colab notebook to directly try them out. | [pndm](./api/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation | | [score_sde_ve](./api/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation | | [score_sde_vp](./api/pipelines/score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation | +| [semantic_stable_diffusion](./api/pipelines/semantic_stable_diffusion) | [**Semantic Guidance**](https://arxiv.org/abs/2301.12247) | Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/semantic-image-editing/blob/main/examples/SemanticGuidance.ipynb) | [stable_diffusion](./api/pipelines/stable_diffusion/text2img) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) | [stable_diffusion](./api/pipelines/stable_diffusion/img2img) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) | [stable_diffusion](./api/pipelines/stable_diffusion/inpaint) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) diff --git a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py index 3e2a9e065516..9317d968aa98 100644 --- a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +++ b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py @@ -13,7 +13,7 @@ from ...pipeline_utils import DiffusionPipeline from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, KarrasDiffusionSchedulers -from ...utils import deprecate, logging +from ...utils import deprecate, logging, replace_example_docstring from . import SemanticStableDiffusionPipelineOutput @@ -27,7 +27,7 @@ >>> pipe = SemanticStableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) >>> pipe = pipe.to("cuda") - >>> out = pipe(prompt='a photo of the face of a woman', generator=gen, num_images_per_prompt=1, guidance_scale=7, + >>> out = pipe(prompt='a photo of the face of a woman', num_images_per_prompt=1, guidance_scale=7, >>> editing_prompt=['smiling, smile', # Concepts to apply 'glasses, wearing glasses', 'curls, wavy hair, curly hair', @@ -37,7 +37,7 @@ >>> edit_guidance_scale=[4, 5, 5, 5.4], # Guidance scale for each concept >>> edit_threshold=[0.99, 0.975, 0.925, 0.96], # Threshold for each concept. Threshold equals the percentile of the latent space that will be discarded. I.e. threshold=0.99 uses 1% of the latent dimensions >>> edit_momentum_scale=0.3, # Momentum scale that will be added to the latent guidance - >>> edit_beta1=0.6, # Momentum beta + >>> edit_mom_beta=0.6, # Momentum beta >>> edit_weights=[1,1,1,1,1] # Weights of the individual concepts against each other >>> ) >>> image = out.images[0] From dc422fef07c2bf814366e127f9b8cf29b2333546 Mon Sep 17 00:00:00 2001 From: manuelbrack Date: Mon, 6 Feb 2023 20:12:23 +0100 Subject: [PATCH 06/14] Fix style and quality --- .../pipelines/semantic_stable_diffusion.mdx | 37 +++++--- .../pipeline_semantic_stable_diffusion.py | 88 +++++++++++-------- 2 files changed, 75 insertions(+), 50 deletions(-) diff --git a/docs/source/en/api/pipelines/semantic_stable_diffusion.mdx b/docs/source/en/api/pipelines/semantic_stable_diffusion.mdx index 420a7e385d51..49137450677f 100644 --- a/docs/source/en/api/pipelines/semantic_stable_diffusion.mdx +++ b/docs/source/en/api/pipelines/semantic_stable_diffusion.mdx @@ -42,19 +42,29 @@ from diffusers import SemanticStableDiffusionPipeline pipe = SemanticStableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) pipe = pipe.to("cuda") -out = pipe(prompt='a photo of the face of a woman', num_images_per_prompt=1, guidance_scale=7, - editing_prompt=['smiling, smile', # Concepts to apply - 'glasses, wearing glasses', - 'curls, wavy hair, curly hair', - 'beard, full beard, mustache'], - reverse_editing_direction=[False, False, False, False], # Direction of guidance i.e. increase all concepts - edit_warmup_steps=[10, 10, 10,10], # Warmup period for each concept - edit_guidance_scale=[4, 5, 5, 5.4], # Guidance scale for each concept - edit_threshold=[0.99, 0.975, 0.925, 0.96], # Threshold for each concept. Threshold equals the percentile of the latent space that will be discarded. I.e. threshold=0.99 uses 1% of the latent dimensions - edit_momentum_scale=0.3, # Momentum scale that will be added to the latent guidance - edit_mom_beta=0.6, # Momentum beta - edit_weights=[1,1,1,1,1] # Weights of the individual concepts against each other - ) +out = pipe( + prompt="a photo of the face of a woman", + num_images_per_prompt=1, + guidance_scale=7, + editing_prompt=[ + "smiling, smile", # Concepts to apply + "glasses, wearing glasses", + "curls, wavy hair, curly hair", + "beard, full beard, mustache", + ], + reverse_editing_direction=[False, False, False, False], # Direction of guidance i.e. increase all concepts + edit_warmup_steps=[10, 10, 10, 10], # Warmup period for each concept + edit_guidance_scale=[4, 5, 5, 5.4], # Guidance scale for each concept + edit_threshold=[ + 0.99, + 0.975, + 0.925, + 0.96, + ], # Threshold for each concept. Threshold equals the percentile of the latent space that will be discarded. I.e. threshold=0.99 uses 1% of the latent dimensions + edit_momentum_scale=0.3, # Momentum scale that will be added to the latent guidance + edit_mom_beta=0.6, # Momentum beta + edit_weights=[1, 1, 1, 1, 1], # Weights of the individual concepts against each other +) ``` For more examples check the colab notebook. @@ -62,7 +72,6 @@ For more examples check the colab notebook. ## StableDiffusionSafePipelineOutput [[autodoc]] pipelines.semantic_stable_diffusion.SemanticStableDiffusionPipelineOutput - all - - __call__ ## SemanticStableDiffusionPipeline [[autodoc]] SemanticStableDiffusionPipeline diff --git a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py index 9317d968aa98..b4c2490c0895 100644 --- a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +++ b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py @@ -1,19 +1,18 @@ import inspect -import warnings from itertools import repeat from typing import Callable, List, Optional, Union -from packaging import version import torch +from packaging import version from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel from ...pipeline_utils import DiffusionPipeline from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker -from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, KarrasDiffusionSchedulers -from ...utils import deprecate, logging, replace_example_docstring +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import deprecate, logging from . import SemanticStableDiffusionPipelineOutput @@ -24,26 +23,44 @@ >>> import torch >>> from diffusers import SemanticStableDiffusionPipeline - >>> pipe = SemanticStableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) + >>> pipe = SemanticStableDiffusionPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16 + ... ) >>> pipe = pipe.to("cuda") - >>> out = pipe(prompt='a photo of the face of a woman', num_images_per_prompt=1, guidance_scale=7, - >>> editing_prompt=['smiling, smile', # Concepts to apply - 'glasses, wearing glasses', - 'curls, wavy hair, curly hair', - 'beard, full beard, mustache'], - >>> reverse_editing_direction=[False, False, False, False], # Direction of guidance i.e. increase all concepts - >>> edit_warmup_steps=[10, 10, 10,10], # Warmup period for each concept - >>> edit_guidance_scale=[4, 5, 5, 5.4], # Guidance scale for each concept - >>> edit_threshold=[0.99, 0.975, 0.925, 0.96], # Threshold for each concept. Threshold equals the percentile of the latent space that will be discarded. I.e. threshold=0.99 uses 1% of the latent dimensions - >>> edit_momentum_scale=0.3, # Momentum scale that will be added to the latent guidance - >>> edit_mom_beta=0.6, # Momentum beta - >>> edit_weights=[1,1,1,1,1] # Weights of the individual concepts against each other - >>> ) + >>> out = pipe( + ... prompt="a photo of the face of a woman", + ... num_images_per_prompt=1, + ... guidance_scale=7, + ... editing_prompt=[ + ... "smiling, smile", # Concepts to apply + ... "glasses, wearing glasses", + ... "curls, wavy hair, curly hair", + ... "beard, full beard, mustache", + ... ], + ... reverse_editing_direction=[ + ... False, + ... False, + ... False, + ... False, + ... ], # Direction of guidance i.e. increase all concepts + ... edit_warmup_steps=[10, 10, 10, 10], # Warmup period for each concept + ... edit_guidance_scale=[4, 5, 5, 5.4], # Guidance scale for each concept + ... edit_threshold=[ + ... 0.99, + ... 0.975, + ... 0.925, + ... 0.96, + ... ], # Threshold for each concept. Threshold equals the percentile of the latent space that will be discarded. I.e. threshold=0.99 uses 1% of the latent dimensions + ... edit_momentum_scale=0.3, # Momentum scale that will be added to the latent guidance + ... edit_mom_beta=0.6, # Momentum beta + ... edit_weights=[1, 1, 1, 1, 1], # Weights of the individual concepts against each other + ... ) >>> image = out.images[0] ``` """ + class SemanticStableDiffusionPipeline(DiffusionPipeline): r""" Pipeline for text-to-image generation with latent editing. @@ -77,15 +94,15 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline): _optional_components = ["safety_checker", "feature_extractor"] def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: KarrasDiffusionSchedulers, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, - requires_safety_checker: bool = True, + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, ): super().__init__() @@ -166,7 +183,6 @@ def __init__( self.register_to_config(requires_safety_checker=requires_safety_checker) @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]], @@ -184,7 +200,7 @@ def __call__( callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: Optional[int] = 1, editing_prompt: Optional[Union[str, List[str]]] = None, - editing_prompt_embeddings: Optional[torch.FloatTensor]=None, + editing_prompt_embeddings: Optional[torch.FloatTensor] = None, reverse_editing_direction: Optional[Union[bool, List[bool]]] = False, edit_guidance_scale: Optional[Union[float, List[float]]] = 5, edit_warmup_steps: Optional[Union[int, List[int]]] = 10, @@ -247,8 +263,8 @@ def __call__( `editing_prompt = None`. Guidance direction of prompt should be specified via `reverse_editing_direction`. editing_prompt_embeddings (`torch.FloatTensor`, *optional*): - Pre-computed embeddings to use for semantic guidance. Guidance direction of embedding should be specified via - `reverse_editing_direction`. + Pre-computed embeddings to use for semantic guidance. Guidance direction of embedding should be + specified via `reverse_editing_direction`. reverse_editing_direction (`bool` or `List[bool]`, *optional*, defaults to `False`): Whether the corresponding prompt in `editing_prompt` should be increased or decreased. edit_guidance_scale (`float` or `List[float]`, *optional*, defaults to 5): @@ -277,11 +293,11 @@ def __call__( correspond to `num_inference_steps`. Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. + [`~pipelines.semantic_stable_diffusion.SemanticStableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.semantic_stable_diffusion.SemanticStableDiffusionPipelineOutput`] if `return_dict` is True, + otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the + second element is a list of `bool`s denoting whether the corresponding generated image likely represents + "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ if isinstance(prompt, str): From 8a157de5a99739c163f43126c70a94aee7f0c609 Mon Sep 17 00:00:00 2001 From: manuelbrack Date: Mon, 6 Feb 2023 20:23:59 +0100 Subject: [PATCH 07/14] Fix doctree --- docs/source/en/_toctree.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 353174149d82..fe6566f8328b 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -126,6 +126,8 @@ title: PNDM - local: api/pipelines/repaint title: RePaint + - local: api/pipelines/semantic_stable_diffusion + title: Semantic Guidance - local: api/pipelines/stable_diffusion_safe title: Safe Stable Diffusion - local: api/pipelines/score_sde_ve From e80aea339b20716cb53020ebcc705bc93b8778ec Mon Sep 17 00:00:00 2001 From: manuelbrack Date: Mon, 13 Feb 2023 12:17:29 +0100 Subject: [PATCH 08/14] Add tests for SEGA --- .../pipeline_semantic_stable_diffusion.py | 208 ++++--- .../semantic_stable_diffusion/__init__.py | 0 .../test_semantic_diffusion.py | 518 ++++++++++++++++++ 3 files changed, 656 insertions(+), 70 deletions(-) create mode 100644 tests/pipelines/semantic_stable_diffusion/__init__.py create mode 100644 tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py diff --git a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py index b4c2490c0895..8d304550974f 100644 --- a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +++ b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py @@ -12,7 +12,7 @@ from ...pipeline_utils import DiffusionPipeline from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from ...schedulers import KarrasDiffusionSchedulers -from ...utils import deprecate, logging +from ...utils import deprecate, logging, randn_tensor from . import SemanticStableDiffusionPipelineOutput @@ -182,18 +182,111 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + @torch.no_grad() def __call__( self, prompt: Union[str, List[str]], - height: int = 512, - width: int = 512, + height: Optional[int] = None, + width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, - generator: Optional[torch.Generator] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, @@ -218,9 +311,9 @@ def __call__( Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to 512): + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -240,8 +333,8 @@ def __call__( Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents @@ -269,9 +362,12 @@ def __call__( Whether the corresponding prompt in `editing_prompt` should be increased or decreased. edit_guidance_scale (`float` or `List[float]`, *optional*, defaults to 5): Guidance scale for semantic guidance. If provided as list values should correspond to `editing_prompt`. + `edit_guidance_scale` is defined as `s_e` of equation 6 of [SEGA + Paper](https://arxiv.org/pdf/2301.12247.pdf). edit_warmup_steps (`float` or `List[float]`, *optional*, defaults to 10): Number of diffusion steps (for each prompt) for which semantic guidance will not be applied. Momentum will still be calculated for those steps and applied once all warmup periods are over. + `edit_warmup_steps` is defined as `delta` (δ) of [SEGA Paper](https://arxiv.org/pdf/2301.12247.pdf). edit_cooldown_steps (`float` or `List[float]`, *optional*, defaults to `None`): Number of diffusion steps (for each prompt) after which semantic guidance will no longer be applied. edit_threshold (`float` or `List[float]`, *optional*, defaults to 0.9): @@ -280,14 +376,17 @@ def __call__( Scale of the momentum to be added to the semantic guidance at each diffusion step. If set to 0.0 momentum will be disabled. Momentum is already built up during warmup, i.e. for diffusion steps smaller than `sld_warmup_steps`. Momentum will only be added to latent guidance once all warmup periods are - finished. + finished. `edit_momentum_scale` is defined as `s_m` of equation 7 of [SEGA + Paper](https://arxiv.org/pdf/2301.12247.pdf). edit_mom_beta (`float`, *optional*, defaults to 0.4): Defines how semantic guidance momentum builds up. `edit_mom_beta` indicates how much of the previous momentum will be kept. Momentum is already built up during warmup, i.e. for diffusion steps smaller - than `edit_warmup_steps`. + than `edit_warmup_steps`. `edit_mom_beta` is defined as `beta_m` (β) of equation 8 of [SEGA + Paper](https://arxiv.org/pdf/2301.12247.pdf). edit_weights (`List[float]`, *optional*, defaults to `None`): Indicates how much each individual concept should influence the overall guidance. If no weights are - provided all concepts are applied equally. + provided all concepts are applied equally. `edit_mom_beta` is defined as `g_i` of equation 9 of [SEGA + Paper](https://arxiv.org/pdf/2301.12247.pdf). sem_guidance (`List[torch.FloatTensor]`, *optional*): List of pre-generated guidance vectors to be applied at generation. Length of the list has to correspond to `num_inference_steps`. @@ -299,24 +398,15 @@ def __call__( second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps) - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) if editing_prompt: enable_edit_guidance = True @@ -433,42 +523,25 @@ def __call__( text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) # get the initial random noise unless the user supplied it - # Unlike in other pipelines, latents need to be generated in the target device - # for 1-to-1 results reproducibility with the CompVis implementation. - # However this currently doesn't work in `mps`. - latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8) - latents_dtype = text_embeddings.dtype - if latents is None: - if self.device.type == "mps": - # randn does not exist on mps - latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to( - self.device - ) - else: - latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) - else: - if latents.shape != latents_shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - latents = latents.to(self.device) - - # set timesteps - self.scheduler.set_timesteps(num_inference_steps) - - # Some schedulers like PNDM have timesteps as arrays - # It's more optimized to move all timesteps to correct device beforehand - timesteps_tensor = self.scheduler.timesteps.to(self.device) + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=self.device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + text_embeddings.dtype, + self.device, + generator, + latents, + ) - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # Initialize edit_momentum to None edit_momentum = None @@ -478,7 +551,7 @@ def __call__( self.edit_estimates = None self.sem_guidance = None - for i, t in enumerate(self.progress_bar(timesteps_tensor)): + for i, t in enumerate(self.progress_bar(timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = ( torch.cat([latents] * (2 + enabled_editing_prompts)) if do_classifier_free_guidance else latents @@ -643,13 +716,8 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - latents = 1 / 0.18215 * latents - image = self.vae.decode(latents).sample - - image = (image / 2 + 0.5).clamp(0, 1) - - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() + # 8. Post-processing + image = self.decode_latents(latents) if self.safety_checker is not None: safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to( diff --git a/tests/pipelines/semantic_stable_diffusion/__init__.py b/tests/pipelines/semantic_stable_diffusion/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py b/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py new file mode 100644 index 000000000000..c7e36610b438 --- /dev/null +++ b/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py @@ -0,0 +1,518 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import random +import tempfile +import unittest + +import numpy as np +import torch + +from diffusers import AutoencoderKL, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel +from diffusers.pipelines.semantic_stable_diffusion import SemanticStableDiffusionPipeline as StableDiffusionPipeline +from diffusers.utils import floats_tensor, nightly, torch_device +from diffusers.utils.testing_utils import require_torch_gpu +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class SafeDiffusionPipelineFastTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + @property + def dummy_image(self): + batch_size = 1 + num_channels = 3 + sizes = (32, 32) + + image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device) + return image + + @property + def dummy_cond_unet(self): + torch.manual_seed(0) + model = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + return model + + @property + def dummy_vae(self): + torch.manual_seed(0) + model = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + return model + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + return CLIPTextModel(config) + + @property + def dummy_extractor(self): + def extract(*args, **kwargs): + class Out: + def __init__(self): + self.pixel_values = torch.ones([0]) + + def to(self, device): + self.pixel_values.to(device) + return self + + return Out() + + return extract + + def test_semantic_diffusion_ddim(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.5644, 0.6018, 0.4799, 0.5267, 0.5585, 0.4641, 0.516, 0.4964, 0.4792]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + def test_semantic_diffusion_pndm(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.5095, 0.5674, 0.4668, 0.5126, 0.5697, 0.4675, 0.5278, 0.4964, 0.4945]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + def test_semantic_diffusion_no_safety_checker(self): + pipe = StableDiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-lms-pipe", safety_checker=None + ) + assert isinstance(pipe, StableDiffusionPipeline) + assert isinstance(pipe.scheduler, LMSDiscreteScheduler) + assert pipe.safety_checker is None + + image = pipe("example prompt", num_inference_steps=2).images[0] + assert image is not None + + # check that there's no error when saving a pipeline with one of the models being None + with tempfile.TemporaryDirectory() as tmpdirname: + pipe.save_pretrained(tmpdirname) + pipe = StableDiffusionPipeline.from_pretrained(tmpdirname) + + # sanity check that the pipeline still works + assert pipe.safety_checker is None + image = pipe("example prompt", num_inference_steps=2).images[0] + assert image is not None + + @unittest.skipIf(torch_device != "cuda", "This test requires a GPU") + def test_semantic_diffusion_fp16(self): + """Test that stable diffusion works with fp16""" + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # put models in fp16 + unet = unet.half() + vae = vae.half() + bert = bert.half() + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + image = sd_pipe([prompt], num_inference_steps=2, output_type="np").images + + assert image.shape == (1, 64, 64, 3) + + +@nightly +@require_torch_gpu +class SemanticDiffusionPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_positive_guidance(self): + torch_device = "cuda" + pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + prompt = "a photo of a cat" + edit = { + "editing_prompt": ["sunglasses"], + "reverse_editing_direction": [False], + "edit_warmup_steps": 10, + "edit_guidance_scale": 6, + "edit_threshold": 0.95, + "edit_momentum_scale": 0.5, + "edit_mom_beta": 0.6, + } + + seed = 3 + guidance_scale = 7 + + # no sega enabled + generator = torch.Generator(torch_device) + generator.manual_seed(seed) + output = pipe( + [prompt], + generator=generator, + guidance_scale=guidance_scale, + num_inference_steps=50, + output_type="np", + width=512, + height=512, + ) + + image = output.images + image_slice = image[0, -3:, -3:, -1] + expected_slice = [ + 0.34673113, + 0.38492733, + 0.37597352, + 0.34086335, + 0.35650748, + 0.35579205, + 0.3384763, + 0.34340236, + 0.3573271, + ] + + assert image.shape == (1, 512, 512, 3) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + # with sega enabled + # generator = torch.manual_seed(seed) + generator.manual_seed(seed) + output = pipe( + [prompt], + generator=generator, + guidance_scale=guidance_scale, + num_inference_steps=50, + output_type="np", + width=512, + height=512, + **edit, + ) + + image = output.images + image_slice = image[0, -3:, -3:, -1] + expected_slice = [ + 0.41887826, + 0.37728766, + 0.30138272, + 0.41416335, + 0.41664985, + 0.36283392, + 0.36191246, + 0.43364465, + 0.43001732, + ] + + assert image.shape == (1, 512, 512, 3) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_negative_guidance(self): + torch_device = "cuda" + pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + prompt = "an image of a crowded boulevard, realistic, 4k" + edit = { + "editing_prompt": "crowd, crowded, people", + "reverse_editing_direction": True, + "edit_warmup_steps": 10, + "edit_guidance_scale": 8.3, + "edit_threshold": 0.9, + "edit_momentum_scale": 0.5, + "edit_mom_beta": 0.6, + } + + seed = 9 + guidance_scale = 7 + + # no sega enabled + generator = torch.Generator(torch_device) + generator.manual_seed(seed) + output = pipe( + [prompt], + generator=generator, + guidance_scale=guidance_scale, + num_inference_steps=50, + output_type="np", + width=512, + height=512, + ) + + image = output.images + image_slice = image[0, -3:, -3:, -1] + expected_slice = [ + 0.43497998, + 0.91814065, + 0.7540739, + 0.55580205, + 0.8467265, + 0.5389691, + 0.62574506, + 0.58897763, + 0.50926757, + ] + + assert image.shape == (1, 512, 512, 3) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + # with sega enabled + # generator = torch.manual_seed(seed) + generator.manual_seed(seed) + output = pipe( + [prompt], + generator=generator, + guidance_scale=guidance_scale, + num_inference_steps=50, + output_type="np", + width=512, + height=512, + **edit, + ) + + image = output.images + image_slice = image[0, -3:, -3:, -1] + expected_slice = [ + 0.3089719, + 0.30500144, + 0.29016042, + 0.30630964, + 0.325687, + 0.29419225, + 0.2908091, + 0.28723598, + 0.27696294, + ] + + assert image.shape == (1, 512, 512, 3) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_multi_cond_guidance(self): + torch_device = "cuda" + pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + prompt = "a castle next to a river" + edit = { + "editing_prompt": ["boat on a river, boat", "monet, impression, sunrise"], + "reverse_editing_direction": False, + "edit_warmup_steps": [15, 18], + "edit_guidance_scale": 6, + "edit_threshold": [0.9, 0.8], + "edit_momentum_scale": 0.5, + "edit_mom_beta": 0.6, + } + + seed = 48 + guidance_scale = 7 + + # no sega enabled + generator = torch.Generator(torch_device) + generator.manual_seed(seed) + output = pipe( + [prompt], + generator=generator, + guidance_scale=guidance_scale, + num_inference_steps=50, + output_type="np", + width=512, + height=512, + ) + + image = output.images + image_slice = image[0, -3:, -3:, -1] + expected_slice = [ + 0.75163555, + 0.76037145, + 0.61785, + 0.9189673, + 0.8627701, + 0.85189694, + 0.8512813, + 0.87012076, + 0.8312857, + ] + + assert image.shape == (1, 512, 512, 3) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + # with sega enabled + # generator = torch.manual_seed(seed) + generator.manual_seed(seed) + output = pipe( + [prompt], + generator=generator, + guidance_scale=guidance_scale, + num_inference_steps=50, + output_type="np", + width=512, + height=512, + **edit, + ) + + image = output.images + image_slice = image[0, -3:, -3:, -1] + expected_slice = [ + 0.73553365, + 0.7537271, + 0.74341905, + 0.66480356, + 0.6472925, + 0.63039416, + 0.64812905, + 0.6749717, + 0.6517102, + ] + + assert image.shape == (1, 512, 512, 3) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 From 771c78d27d86fb5886bfc989e52b9eb6cdbcd806 Mon Sep 17 00:00:00 2001 From: Manuel Brack Date: Thu, 16 Feb 2023 13:52:08 +0100 Subject: [PATCH 09/14] Update src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py Co-authored-by: Patrick von Platen --- .../pipeline_semantic_stable_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py index 8d304550974f..7c657c4af2e0 100644 --- a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +++ b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py @@ -284,7 +284,7 @@ def __call__( num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, + num_images_per_prompt: int = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, From 49351129dea2d9366c5373e0697240240c6552b4 Mon Sep 17 00:00:00 2001 From: Manuel Brack Date: Thu, 16 Feb 2023 13:52:18 +0100 Subject: [PATCH 10/14] Update src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py Co-authored-by: Patrick von Platen --- .../pipeline_semantic_stable_diffusion.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py index 7c657c4af2e0..84694b8e3dfa 100644 --- a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +++ b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py @@ -303,7 +303,6 @@ def __call__( edit_mom_beta: Optional[float] = 0.4, edit_weights: Optional[List[float]] = None, sem_guidance: Optional[List[torch.FloatTensor]] = None, - **kwargs, ): r""" Function invoked when calling the pipeline for generation. From 624498559dfdc6054562df934797cd8be2951623 Mon Sep 17 00:00:00 2001 From: Manuel Brack Date: Thu, 16 Feb 2023 13:52:29 +0100 Subject: [PATCH 11/14] Update src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py Co-authored-by: Patrick von Platen --- .../pipeline_semantic_stable_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py index 84694b8e3dfa..ad9187fe09d1 100644 --- a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +++ b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py @@ -291,7 +291,7 @@ def __call__( output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: Optional[int] = 1, + callback_steps: int = 1, editing_prompt: Optional[Union[str, List[str]]] = None, editing_prompt_embeddings: Optional[torch.FloatTensor] = None, reverse_editing_direction: Optional[Union[bool, List[bool]]] = False, From 8ea9e85f643c221a20a7519934128c2262036e70 Mon Sep 17 00:00:00 2001 From: manuelbrack Date: Thu, 16 Feb 2023 16:07:36 +0100 Subject: [PATCH 12/14] Make compatible with half precision --- .../pipeline_semantic_stable_diffusion.py | 40 ++++++--- .../test_semantic_diffusion.py | 83 +++++++++++++++++++ 2 files changed, 111 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py index ad9187fe09d1..6ce3e53f6d7b 100644 --- a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +++ b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py @@ -293,7 +293,7 @@ def __call__( callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, editing_prompt: Optional[Union[str, List[str]]] = None, - editing_prompt_embeddings: Optional[torch.FloatTensor] = None, + editing_prompt_embeddings: Optional[torch.Tensor] = None, reverse_editing_direction: Optional[Union[bool, List[bool]]] = False, edit_guidance_scale: Optional[Union[float, List[float]]] = 5, edit_warmup_steps: Optional[Union[int, List[int]]] = 10, @@ -302,7 +302,7 @@ def __call__( edit_momentum_scale: Optional[float] = 0.1, edit_mom_beta: Optional[float] = 0.4, edit_weights: Optional[List[float]] = None, - sem_guidance: Optional[List[torch.FloatTensor]] = None, + sem_guidance: Optional[List[torch.Tensor]] = None, ): r""" Function invoked when calling the pipeline for generation. @@ -354,7 +354,7 @@ def __call__( The prompt or prompts to use for Semantic guidance. Semantic guidance is disabled by setting `editing_prompt = None`. Guidance direction of prompt should be specified via `reverse_editing_direction`. - editing_prompt_embeddings (`torch.FloatTensor`, *optional*): + editing_prompt_embeddings (`torch.Tensor>`, *optional*): Pre-computed embeddings to use for semantic guidance. Guidance direction of embedding should be specified via `reverse_editing_direction`. reverse_editing_direction (`bool` or `List[bool]`, *optional*, defaults to `False`): @@ -386,7 +386,7 @@ def __call__( Indicates how much each individual concept should influence the overall guidance. If no weights are provided all concepts are applied equally. `edit_mom_beta` is defined as `g_i` of equation 9 of [SEGA Paper](https://arxiv.org/pdf/2301.12247.pdf). - sem_guidance (`List[torch.FloatTensor]`, *optional*): + sem_guidance (`List[torch.Tensor]`, *optional*): List of pre-generated guidance vectors to be applied at generation. Length of the list has to correspond to `num_inference_steps`. @@ -591,10 +591,14 @@ def __call__( if enable_edit_guidance: concept_weights = torch.zeros( - (len(noise_pred_edit_concepts), noise_guidance.shape[0]), device=self.device + (len(noise_pred_edit_concepts), noise_guidance.shape[0]), + device=self.device, + dtype=noise_guidance.dtype, ) noise_guidance_edit = torch.zeros( - (len(noise_pred_edit_concepts), *noise_guidance.shape), device=self.device + (len(noise_pred_edit_concepts), *noise_guidance.shape), + device=self.device, + dtype=noise_guidance.dtype, ) # noise_guidance_edit = torch.zeros_like(noise_guidance) warmup_inds = [] @@ -644,12 +648,23 @@ def __call__( concept_weights[c, :] = tmp_weights noise_guidance_edit_tmp = noise_guidance_edit_tmp * edit_guidance_scale_c - tmp = torch.quantile( - torch.abs(noise_guidance_edit_tmp).flatten(start_dim=2), - edit_threshold_c, - dim=2, - keepdim=False, - ) + + # torch.quantile function expects float32 + if noise_guidance_edit_tmp.dtype == torch.float32: + tmp = torch.quantile( + torch.abs(noise_guidance_edit_tmp).flatten(start_dim=2), + edit_threshold_c, + dim=2, + keepdim=False, + ) + else: + tmp = torch.quantile( + torch.abs(noise_guidance_edit_tmp).flatten(start_dim=2).to(torch.float32), + edit_threshold_c, + dim=2, + keepdim=False, + ).to(noise_guidance_edit_tmp.dtype) + noise_guidance_edit_tmp = torch.where( torch.abs(noise_guidance_edit_tmp) >= tmp[:, :, None, None], noise_guidance_edit_tmp, @@ -692,6 +707,7 @@ def __call__( ) concept_weights = torch.nan_to_num(concept_weights) + noise_guidance_edit = torch.einsum("cb,cbijk->bijk", concept_weights, noise_guidance_edit) noise_guidance_edit = noise_guidance_edit + edit_momentum_scale * edit_momentum diff --git a/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py b/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py index c7e36610b438..804e3a8ccdce 100644 --- a/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py +++ b/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py @@ -516,3 +516,86 @@ def test_multi_cond_guidance(self): assert image.shape == (1, 512, 512, 3) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_guidance_fp16(self): + torch_device = "cuda" + pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + prompt = "a photo of a cat" + edit = { + "editing_prompt": ["sunglasses"], + "reverse_editing_direction": [False], + "edit_warmup_steps": 10, + "edit_guidance_scale": 6, + "edit_threshold": 0.95, + "edit_momentum_scale": 0.5, + "edit_mom_beta": 0.6, + } + + seed = 3 + guidance_scale = 7 + + # no sega enabled + generator = torch.Generator(torch_device) + generator.manual_seed(seed) + output = pipe( + [prompt], + generator=generator, + guidance_scale=guidance_scale, + num_inference_steps=50, + output_type="np", + width=512, + height=512, + ) + + image = output.images + image_slice = image[0, -3:, -3:, -1] + expected_slice = [ + 0.34887695, + 0.3876953, + 0.375, + 0.34423828, + 0.3581543, + 0.35717773, + 0.3383789, + 0.34570312, + 0.359375, + ] + + assert image.shape == (1, 512, 512, 3) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + # with sega enabled + # generator = torch.manual_seed(seed) + generator.manual_seed(seed) + output = pipe( + [prompt], + generator=generator, + guidance_scale=guidance_scale, + num_inference_steps=50, + output_type="np", + width=512, + height=512, + **edit, + ) + + image = output.images + image_slice = image[0, -3:, -3:, -1] + expected_slice = [ + 0.42285156, + 0.36914062, + 0.29077148, + 0.42041016, + 0.41918945, + 0.35498047, + 0.3618164, + 0.4423828, + 0.43115234, + ] + + assert image.shape == (1, 512, 512, 3) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 From 6ba73fb01b15de1d1fdb82b9dbd290b0099ff5c7 Mon Sep 17 00:00:00 2001 From: manuelbrack Date: Fri, 17 Feb 2023 09:35:12 +0100 Subject: [PATCH 13/14] Change deprecation warning to throw an exception --- .../pipeline_semantic_stable_diffusion.py | 48 +++++++------------ 1 file changed, 18 insertions(+), 30 deletions(-) diff --git a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py index 6ce3e53f6d7b..883033945eec 100644 --- a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +++ b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py @@ -105,33 +105,25 @@ def __init__( requires_safety_checker: bool = True, ): super().__init__() - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" - f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " - "to update the config accordingly as leaving `steps_offset` might led to incorrect results" - " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," - " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file" + f"Unsupported scheduler configuration: {scheduler} is outdated. `steps_offset` should be set to 1" + f" instead of {scheduler.config.steps_offset}. Please make sure to update the config accordingly as" + " leaving `steps_offset` is no longer supported. If you have downloaded this checkpoint from the" + " Hugging Face Hub, it would be very nice if you could open a Pull request for the" + " `scheduler/scheduler_config.json` file" ) - deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["steps_offset"] = 1 - scheduler._internal_dict = FrozenDict(new_config) + raise ValueError(deprecation_message) if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + f"Unsupported scheduler configuration: {scheduler} has not set the configuration `clip_sample`." " `clip_sample` should be set to False in the configuration file. Please make sure to update the" - " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" - " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" - " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + " config accordingly as not setting `clip_sample` in the config is no longer supported. If you have" + " downloaded this checkpoint from the Hugging Face Hub, it would be very nice if you could open a" + " Pull request for the `scheduler/scheduler_config.json` file" ) - deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["clip_sample"] = False - scheduler._internal_dict = FrozenDict(new_config) + raise ValueError(deprecation_message) if safety_checker is None and requires_safety_checker: logger.warning( @@ -155,20 +147,16 @@ def __init__( is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( - "The configuration file of the unet has set the default `sample_size` to smaller than" - " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" - " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " Unsupported unet configuration: The configuration of the unet has set the default `sample_size` to" + " smaller than 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of" + " the following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" - " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" - " in the config might lead to incorrect results in future versions. If you have downloaded this" - " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" - " the `unet/config.json` file" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32` in" + " the config is no longer supported. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `unet/config.json` file" ) - deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(unet.config) - new_config["sample_size"] = 64 - unet._internal_dict = FrozenDict(new_config) + raise ValueError(deprecation_message) self.register_modules( vae=vae, From 1938d384c2ec3c79fc2e4ca0e01d0302d33aa178 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 17 Feb 2023 14:29:49 +0000 Subject: [PATCH 14/14] update --- docs/source/en/_toctree.yml | 4 +- .../semantic_stable_diffusion/__init__.py | 1 - .../pipeline_semantic_stable_diffusion.py | 43 +------------------ .../dummy_torch_and_transformers_objects.py | 3 +- .../test_semantic_diffusion.py | 2 +- 5 files changed, 6 insertions(+), 47 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 4d5d22b9b6d5..9280813fa51c 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -128,12 +128,12 @@ title: PNDM - local: api/pipelines/repaint title: RePaint - - local: api/pipelines/semantic_stable_diffusion - title: Semantic Guidance - local: api/pipelines/stable_diffusion_safe title: Safe Stable Diffusion - local: api/pipelines/score_sde_ve title: Score SDE VE + - local: api/pipelines/semantic_stable_diffusion + title: Semantic Guidance - sections: - local: api/pipelines/stable_diffusion/overview title: Overview diff --git a/src/diffusers/pipelines/semantic_stable_diffusion/__init__.py b/src/diffusers/pipelines/semantic_stable_diffusion/__init__.py index a7716f58e609..0e312c5e3013 100644 --- a/src/diffusers/pipelines/semantic_stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/semantic_stable_diffusion/__init__.py @@ -3,7 +3,6 @@ from typing import List, Optional, Union import numpy as np - import PIL from PIL import Image diff --git a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py index 883033945eec..a421a844c329 100644 --- a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +++ b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py @@ -3,16 +3,13 @@ from typing import Callable, List, Optional, Union import torch - -from packaging import version from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer -from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel from ...pipeline_utils import DiffusionPipeline from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from ...schedulers import KarrasDiffusionSchedulers -from ...utils import deprecate, logging, randn_tensor +from ...utils import logging, randn_tensor from . import SemanticStableDiffusionPipelineOutput @@ -105,25 +102,6 @@ def __init__( requires_safety_checker: bool = True, ): super().__init__() - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: - deprecation_message = ( - f"Unsupported scheduler configuration: {scheduler} is outdated. `steps_offset` should be set to 1" - f" instead of {scheduler.config.steps_offset}. Please make sure to update the config accordingly as" - " leaving `steps_offset` is no longer supported. If you have downloaded this checkpoint from the" - " Hugging Face Hub, it would be very nice if you could open a Pull request for the" - " `scheduler/scheduler_config.json` file" - ) - raise ValueError(deprecation_message) - - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: - deprecation_message = ( - f"Unsupported scheduler configuration: {scheduler} has not set the configuration `clip_sample`." - " `clip_sample` should be set to False in the configuration file. Please make sure to update the" - " config accordingly as not setting `clip_sample` in the config is no longer supported. If you have" - " downloaded this checkpoint from the Hugging Face Hub, it would be very nice if you could open a" - " Pull request for the `scheduler/scheduler_config.json` file" - ) - raise ValueError(deprecation_message) if safety_checker is None and requires_safety_checker: logger.warning( @@ -141,23 +119,6 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 - if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: - deprecation_message = ( - " Unsupported unet configuration: The configuration of the unet has set the default `sample_size` to" - " smaller than 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of" - " the following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" - " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" - " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" - " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32` in" - " the config is no longer supported. If you have downloaded this checkpoint from the Hugging Face Hub," - " it would be very nice if you could open a Pull request for the `unet/config.json` file" - ) - raise ValueError(deprecation_message) - self.register_modules( vae=vae, text_encoder=text_encoder, @@ -175,7 +136,7 @@ def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 0d91b33f0983..195cf7903947 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -77,7 +77,6 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) - class SemanticStableDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -92,8 +91,8 @@ def from_config(cls, *args, **kwargs): def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class StableDiffusionAttendAndExcitePipeline(metaclass=DummyObject): +class StableDiffusionAttendAndExcitePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): diff --git a/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py b/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py index 804e3a8ccdce..329fdd0bfd7e 100644 --- a/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py +++ b/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py @@ -20,12 +20,12 @@ import numpy as np import torch +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer from diffusers import AutoencoderKL, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel from diffusers.pipelines.semantic_stable_diffusion import SemanticStableDiffusionPipeline as StableDiffusionPipeline from diffusers.utils import floats_tensor, nightly, torch_device from diffusers.utils.testing_utils import require_torch_gpu -from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer torch.backends.cuda.matmul.allow_tf32 = False