
Commit cc1c4de

Add support for Guess Mode to StableDiffusionControlNetPipeline (huggingface#2998)
* add guess mode (WIP)
* fix uncond/cond order
* support guidance_scale=1.0 and batch != 1
* remove magic coeff
* add docstring
* add integration test
* add document to controlnet.mdx
* made the comments a bit more explanatory
* fix table
1 parent 6ea2d26 commit cc1c4de

2 files changed: +47 −6 lines changed

models/controlnet.py

Lines changed: 9 additions & 2 deletions
@@ -456,6 +456,7 @@ def forward(
         timestep_cond: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guess_mode: bool = False,
         return_dict: bool = True,
     ) -> Union[ControlNetOutput, Tuple]:
         # check channel order
@@ -556,8 +557,14 @@ def forward(
         mid_block_res_sample = self.controlnet_mid_block(sample)
 
         # 6. scaling
-        down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
-        mid_block_res_sample *= conditioning_scale
+        if guess_mode:
+            scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1)  # 0.1 to 1.0
+            scales *= conditioning_scale
+            down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
+            mid_block_res_sample *= scales[-1]  # last one
+        else:
+            down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
+            mid_block_res_sample *= conditioning_scale
 
         if not return_dict:
             return (down_block_res_samples, mid_block_res_sample)
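
For intuition, a minimal sketch of the guess-mode weighting (assuming the standard SD ControlNet, which yields 12 down-block residuals, so len(down_block_res_samples) + 1 == 13):

import torch

# Exponentially increasing weights from 0.1 (shallowest block) to 1.0
# (mid block): shallow features are damped, deep features pass through.
scales = torch.logspace(-1, 0, 13)
print(scales[0].item(), scales[-1].item())  # ~0.1 and 1.0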

pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py

Lines changed: 38 additions & 4 deletions
@@ -118,6 +118,7 @@ def forward(
         timestep_cond: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guess_mode: bool = False,
         return_dict: bool = True,
     ) -> Union[ControlNetOutput, Tuple]:
         for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
@@ -131,6 +132,7 @@ def forward(
                 timestep_cond,
                 attention_mask,
                 cross_attention_kwargs,
+                guess_mode,
                 return_dict,
             )
 
@@ -627,7 +629,16 @@ def check_image(self, image, prompt, prompt_embeds):
             )
 
     def prepare_image(
-        self, image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance
+        self,
+        image,
+        width,
+        height,
+        batch_size,
+        num_images_per_prompt,
+        device,
+        dtype,
+        do_classifier_free_guidance,
+        guess_mode,
     ):
         if not isinstance(image, torch.Tensor):
             if isinstance(image, PIL.Image.Image):
@@ -664,7 +675,7 @@ def prepare_image(
 
         image = image.to(device=device, dtype=dtype)
 
-        if do_classifier_free_guidance:
+        if do_classifier_free_guidance and not guess_mode:
             image = torch.cat([image] * 2)
 
         return image
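
A toy illustration of the change above (flag values and shapes are assumptions): in guess mode ControlNet only sees the conditional half of the batch, so the conditioning image is no longer doubled for classifier-free guidance.

import torch

do_classifier_free_guidance, guess_mode = True, True
image = torch.randn(1, 3, 512, 512)  # one prepared conditioning image

if do_classifier_free_guidance and not guess_mode:
    image = torch.cat([image] * 2)  # uncond + cond halves under plain CFG

print(image.shape)  # stays torch.Size([1, 3, 512, 512]) in guess mode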
@@ -747,6 +758,7 @@ def __call__(
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+        guess_mode: bool = False,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -819,6 +831,10 @@ def __call__(
                 The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
                 to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
                 corresponding scale as a list.
+            guess_mode (`bool`, *optional*, defaults to `False`):
+                In this mode, the ControlNet encoder will try its best to recognize the content of the input image
+                even if you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
+
         Examples:
 
         Returns:
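
A hedged usage sketch of the new flag (the model IDs and input file are assumptions, not part of this commit):

import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.utils import load_image

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

canny_image = load_image("canny_edges.png")  # placeholder: a precomputed Canny edge map
# Empty prompt; per the docstring, a guidance_scale between 3.0 and 5.0 is recommended.
image = pipe("", image=canny_image, guess_mode=True, guidance_scale=3.0).images[0]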
@@ -883,6 +899,7 @@ def __call__(
                 device=device,
                 dtype=self.controlnet.dtype,
                 do_classifier_free_guidance=do_classifier_free_guidance,
+                guess_mode=guess_mode,
             )
         elif isinstance(self.controlnet, MultiControlNetModel):
             images = []
@@ -897,6 +914,7 @@ def __call__(
                     device=device,
                     dtype=self.controlnet.dtype,
                     do_classifier_free_guidance=do_classifier_free_guidance,
+                    guess_mode=guess_mode,
                 )
 
                 images.append(image_)
@@ -934,15 +952,31 @@ def __call__(
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                 # controlnet(s) inference
+                if guess_mode and do_classifier_free_guidance:
+                    # Infer ControlNet only for the conditional batch.
+                    controlnet_latent_model_input = latents
+                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+                else:
+                    controlnet_latent_model_input = latent_model_input
+                    controlnet_prompt_embeds = prompt_embeds
+
                 down_block_res_samples, mid_block_res_sample = self.controlnet(
-                    latent_model_input,
+                    controlnet_latent_model_input,
                     t,
-                    encoder_hidden_states=prompt_embeds,
+                    encoder_hidden_states=controlnet_prompt_embeds,
                     controlnet_cond=image,
                     conditioning_scale=controlnet_conditioning_scale,
+                    guess_mode=guess_mode,
                     return_dict=False,
                 )
 
+                if guess_mode and do_classifier_free_guidance:
+                    # ControlNet was inferred only for the conditional batch.
+                    # To apply its output to both the unconditional and conditional batches,
+                    # add zeros to the unconditional batch so it stays unchanged.
+                    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+                    mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
                 # predict the noise residual
                 noise_pred = self.unet(
                     latent_model_input,
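
A toy illustration of the zero-padding trick above (the tensor shape is an assumption): ControlNet ran only on the conditional half, so zeros are concatenated for the unconditional half, leaving it untouched when the residuals are added in the UNet.

import torch

mid_block_res_sample = torch.randn(1, 1280, 8, 8)  # conditional-only residual
mid_block_res_sample = torch.cat(
    [torch.zeros_like(mid_block_res_sample), mid_block_res_sample]
)
print(mid_block_res_sample.shape)  # torch.Size([2, 1280, 8, 8]): uncond half is all zeros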
