Skip to content

Commit c09c4f3

Browse files
rupertmenneer, patrickvonplaten, williamberman
authored
Adding 'strength' parameter to StableDiffusionInpaintingPipeline (#3424)
* Added explanation of 'strength' parameter * Added get_timesteps function which relies on new strength parameter * Added `strength` parameter which defaults to 1. * Swapped ordering so `noise_timestep` can be calculated before masking the image this is required when you aren't applying 100% noise to the masked region, e.g. strength < 1. * Added strength to check_inputs, throws error if out of range * Changed `prepare_latents` to initialise latents w.r.t strength inspired from the stable diffusion img2img pipeline, init latents are initialised by converting the init image into a VAE latent and adding noise (based upon the strength parameter passed in), e.g. random when strength = 1, or the init image at strength = 0. * WIP: Added a unit test for the new strength parameter in the StableDiffusionInpaintingPipeline still need to add correct regression values * Created a is_strength_max to initialise from pure random noise * Updated unit tests w.r.t new strength parameter + fixed new strength unit test * renamed parameter to avoid confusion with variable of same name * Updated regression values for new strength test - now passes * removed 'copied from' comment as this method is now different and divergent from the copy * Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py Co-authored-by: Patrick von Platen <[email protected]> * Ensure backwards compatibility for prepare_mask_and_masked_image created a return_image boolean and initialised to false * Ensure backwards compatibility for prepare_latents * Fixed copy check typo * Fixes w.r.t backward compatibility changes * make style * keep function argument ordering same for backwards compatibility in callees with copied from statements * make fix-copies --------- Co-authored-by: Patrick von Platen <[email protected]> Co-authored-by: William Berman <[email protected]>
1 parent 6070b32 commit c09c4f3

File tree

3 files changed

+211
-37
lines changed

3 files changed

+211
-37
lines changed

src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py

+44-3
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@
9999

100100

101101
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.prepare_mask_and_masked_image
102-
def prepare_mask_and_masked_image(image, mask, height, width):
102+
def prepare_mask_and_masked_image(image, mask, height, width, return_image=False):
103103
"""
104104
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
105105
converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
@@ -209,6 +209,10 @@ def prepare_mask_and_masked_image(image, mask, height, width):
209209

210210
masked_image = image * (mask < 0.5)
211211

212+
# n.b. ensure backwards compatibility as old function does not return image
213+
if return_image:
214+
return mask, masked_image, image
215+
212216
return mask, masked_image
213217

214218

@@ -795,21 +799,58 @@ def prepare_control_image(
795799
return image
796800

797801
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_latents
798-
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
802+
def prepare_latents(
803+
self,
804+
batch_size,
805+
num_channels_latents,
806+
height,
807+
width,
808+
dtype,
809+
device,
810+
generator,
811+
latents=None,
812+
image=None,
813+
timestep=None,
814+
is_strength_max=True,
815+
):
799816
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
800817
if isinstance(generator, list) and len(generator) != batch_size:
801818
raise ValueError(
802819
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
803820
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
804821
)
805822

823+
if (image is None or timestep is None) and not is_strength_max:
824+
raise ValueError(
825+
"Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
826+
"However, either the image or the noise timestep has not been provided."
827+
)
828+
806829
if latents is None:
807-
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
830+
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
831+
if is_strength_max:
832+
# if strength is 100% then simply initialise the latents to noise
833+
latents = noise
834+
else:
835+
# otherwise initialise latents as init image + noise
836+
image = image.to(device=device, dtype=dtype)
837+
if isinstance(generator, list):
838+
image_latents = [
839+
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
840+
for i in range(batch_size)
841+
]
842+
else:
843+
image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
844+
845+
image_latents = self.vae.config.scaling_factor * image_latents
846+
847+
latents = self.scheduler.add_noise(image_latents, noise, timestep)
808848
else:
809849
latents = latents.to(device)
810850

811851
# scale the initial noise by the standard deviation required by the scheduler
812852
latents = latents * self.scheduler.init_noise_sigma
853+
813854
return latents
814855

815856
def _default_height_width(self, height, width, image):

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

+83-10
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
3737

3838

39-
def prepare_mask_and_masked_image(image, mask, height, width):
39+
def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
4040
"""
4141
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
4242
converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
@@ -146,6 +146,10 @@ def prepare_mask_and_masked_image(image, mask, height, width):
146146

147147
masked_image = image * (mask < 0.5)
148148

149+
# n.b. ensure backwards compatibility as old function does not return image
150+
if return_image:
151+
return mask, masked_image, image
152+
149153
return mask, masked_image
150154

151155

@@ -552,17 +556,20 @@ def decode_latents(self, latents):
552556
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
553557
return image
554558

555-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
556559
def check_inputs(
557560
self,
558561
prompt,
559562
height,
560563
width,
564+
strength,
561565
callback_steps,
562566
negative_prompt=None,
563567
prompt_embeds=None,
564568
negative_prompt_embeds=None,
565569
):
570+
if strength < 0 or strength > 1:
571+
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
572+
566573
if height % 8 != 0 or width % 8 != 0:
567574
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
568575

@@ -600,22 +607,58 @@ def check_inputs(
600607
f" {negative_prompt_embeds.shape}."
601608
)
602609

603-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
604-
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
610+
def prepare_latents(
611+
self,
612+
batch_size,
613+
num_channels_latents,
614+
height,
615+
width,
616+
dtype,
617+
device,
618+
generator,
619+
latents=None,
620+
image=None,
621+
timestep=None,
622+
is_strength_max=True,
623+
):
605624
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
606625
if isinstance(generator, list) and len(generator) != batch_size:
607626
raise ValueError(
608627
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
609628
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
610629
)
611630

631+
if (image is None or timestep is None) and not is_strength_max:
632+
raise ValueError(
633+
"Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
634+
"However, either the image or the noise timestep has not been provided."
635+
)
636+
612637
if latents is None:
613-
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
638+
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
639+
if is_strength_max:
640+
# if strength is 100% then simply initialise the latents to noise
641+
latents = noise
642+
else:
643+
# otherwise initialise latents as init image + noise
644+
image = image.to(device=device, dtype=dtype)
645+
if isinstance(generator, list):
646+
image_latents = [
647+
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
648+
for i in range(batch_size)
649+
]
650+
else:
651+
image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
652+
653+
image_latents = self.vae.config.scaling_factor * image_latents
654+
655+
latents = self.scheduler.add_noise(image_latents, noise, timestep)
614656
else:
615657
latents = latents.to(device)
616658

617659
# scale the initial noise by the standard deviation required by the scheduler
618660
latents = latents * self.scheduler.init_noise_sigma
661+
619662
return latents
620663

621664
def prepare_mask_latents(
@@ -669,6 +712,16 @@ def prepare_mask_latents(
669712
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
670713
return mask, masked_image_latents
671714

715+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
716+
def get_timesteps(self, num_inference_steps, strength, device):
717+
# get the original timestep using init_timestep
718+
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
719+
720+
t_start = max(num_inference_steps - init_timestep, 0)
721+
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
722+
723+
return timesteps, num_inference_steps - t_start
724+
672725
@torch.no_grad()
673726
def __call__(
674727
self,
@@ -677,6 +730,7 @@ def __call__(
677730
mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
678731
height: Optional[int] = None,
679732
width: Optional[int] = None,
733+
strength: float = 1.0,
680734
num_inference_steps: int = 50,
681735
guidance_scale: float = 7.5,
682736
negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -710,6 +764,13 @@ def __call__(
710764
The height in pixels of the generated image.
711765
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
712766
The width in pixels of the generated image.
767+
strength (`float`, *optional*, defaults to 1.):
768+
Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
769+
between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
770+
`strength`. The number of denoising steps depends on the amount of noise initially added. When
771+
`strength` is 1, added noise will be maximum and the denoising process will run for the full number of
772+
iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked
773+
portion of the reference `image`.
713774
num_inference_steps (`int`, *optional*, defaults to 50):
714775
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
715776
expense of slower inference.
@@ -802,6 +863,7 @@ def __call__(
802863
prompt,
803864
height,
804865
width,
866+
strength,
805867
callback_steps,
806868
negative_prompt,
807869
prompt_embeds,
@@ -833,12 +895,20 @@ def __call__(
833895
negative_prompt_embeds=negative_prompt_embeds,
834896
)
835897

836-
# 4. Preprocess mask and image - resizes image and mask w.r.t height and width
837-
mask, masked_image = prepare_mask_and_masked_image(image, mask_image, height, width)
838-
839-
# 5. set timesteps
898+
# 4. set timesteps
840899
self.scheduler.set_timesteps(num_inference_steps, device=device)
841-
timesteps = self.scheduler.timesteps
900+
timesteps, num_inference_steps = self.get_timesteps(
901+
num_inference_steps=num_inference_steps, strength=strength, device=device
902+
)
903+
# at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
904+
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
905+
# create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
906+
is_strength_max = strength == 1.0
907+
908+
# 5. Preprocess mask and image
909+
mask, masked_image, init_image = prepare_mask_and_masked_image(
910+
image, mask_image, height, width, return_image=True
911+
)
842912

843913
# 6. Prepare latent variables
844914
num_channels_latents = self.vae.config.latent_channels
@@ -851,6 +921,9 @@ def __call__(
851921
device,
852922
generator,
853923
latents,
924+
image=init_image,
925+
timestep=latent_timestep,
926+
is_strength_max=is_strength_max,
854927
)
855928

856929
# 7. Prepare mask latent variables

0 commit comments

Comments (0)