@@ -118,6 +118,7 @@ def forward(
118
118
timestep_cond : Optional [torch .Tensor ] = None ,
119
119
attention_mask : Optional [torch .Tensor ] = None ,
120
120
cross_attention_kwargs : Optional [Dict [str , Any ]] = None ,
121
+ guess_mode : bool = False ,
121
122
return_dict : bool = True ,
122
123
) -> Union [ControlNetOutput , Tuple ]:
123
124
for i , (image , scale , controlnet ) in enumerate (zip (controlnet_cond , conditioning_scale , self .nets )):
@@ -131,6 +132,7 @@ def forward(
131
132
timestep_cond ,
132
133
attention_mask ,
133
134
cross_attention_kwargs ,
135
+ guess_mode ,
134
136
return_dict ,
135
137
)
136
138
@@ -627,7 +629,16 @@ def check_image(self, image, prompt, prompt_embeds):
627
629
)
628
630
629
631
def prepare_image (
630
- self , image , width , height , batch_size , num_images_per_prompt , device , dtype , do_classifier_free_guidance
632
+ self ,
633
+ image ,
634
+ width ,
635
+ height ,
636
+ batch_size ,
637
+ num_images_per_prompt ,
638
+ device ,
639
+ dtype ,
640
+ do_classifier_free_guidance ,
641
+ guess_mode ,
631
642
):
632
643
if not isinstance (image , torch .Tensor ):
633
644
if isinstance (image , PIL .Image .Image ):
@@ -664,7 +675,7 @@ def prepare_image(
664
675
665
676
image = image .to (device = device , dtype = dtype )
666
677
667
- if do_classifier_free_guidance :
678
+ if do_classifier_free_guidance and not guess_mode :
668
679
image = torch .cat ([image ] * 2 )
669
680
670
681
return image
@@ -747,6 +758,7 @@ def __call__(
747
758
callback_steps : int = 1 ,
748
759
cross_attention_kwargs : Optional [Dict [str , Any ]] = None ,
749
760
controlnet_conditioning_scale : Union [float , List [float ]] = 1.0 ,
761
+ guess_mode : bool = False ,
750
762
):
751
763
r"""
752
764
Function invoked when calling the pipeline for generation.
@@ -819,6 +831,10 @@ def __call__(
819
831
The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
820
832
to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
821
833
corresponding scale as a list.
834
+ guess_mode (`bool`, *optional*, defaults to `False`):
835
+ In this mode, the ControlNet encoder will try its best to recognize the content of the input image even if
836
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
837
+
822
838
Examples:
823
839
824
840
Returns:
@@ -883,6 +899,7 @@ def __call__(
883
899
device = device ,
884
900
dtype = self .controlnet .dtype ,
885
901
do_classifier_free_guidance = do_classifier_free_guidance ,
902
+ guess_mode = guess_mode ,
886
903
)
887
904
elif isinstance (self .controlnet , MultiControlNetModel ):
888
905
images = []
@@ -897,6 +914,7 @@ def __call__(
897
914
device = device ,
898
915
dtype = self .controlnet .dtype ,
899
916
do_classifier_free_guidance = do_classifier_free_guidance ,
917
+ guess_mode = guess_mode ,
900
918
)
901
919
902
920
images .append (image_ )
@@ -934,15 +952,31 @@ def __call__(
934
952
latent_model_input = self .scheduler .scale_model_input (latent_model_input , t )
935
953
936
954
# controlnet(s) inference
955
+ if guess_mode and do_classifier_free_guidance :
956
+ # Infer ControlNet only for the conditional batch.
957
+ controlnet_latent_model_input = latents
958
+ controlnet_prompt_embeds = prompt_embeds .chunk (2 )[1 ]
959
+ else :
960
+ controlnet_latent_model_input = latent_model_input
961
+ controlnet_prompt_embeds = prompt_embeds
962
+
937
963
down_block_res_samples , mid_block_res_sample = self .controlnet (
938
- latent_model_input ,
964
+ controlnet_latent_model_input ,
939
965
t ,
940
- encoder_hidden_states = prompt_embeds ,
966
+ encoder_hidden_states = controlnet_prompt_embeds ,
941
967
controlnet_cond = image ,
942
968
conditioning_scale = controlnet_conditioning_scale ,
969
+ guess_mode = guess_mode ,
943
970
return_dict = False ,
944
971
)
945
972
973
+ if guess_mode and do_classifier_free_guidance :
974
+ # ControlNet was inferred only for the conditional batch.
975
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
976
+ # add 0 to the unconditional batch to keep it unchanged.
977
+ down_block_res_samples = [torch .cat ([torch .zeros_like (d ), d ]) for d in down_block_res_samples ]
978
+ mid_block_res_sample = torch .cat ([torch .zeros_like (mid_block_res_sample ), mid_block_res_sample ])
979
+
946
980
# predict the noise residual
947
981
noise_pred = self .unet (
948
982
latent_model_input ,
0 commit comments