diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index f865b42eb9d5..115faf3ad151 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -808,7 +808,7 @@ def __init__( self.gradient_checkpointing = False def forward( - self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None + self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, down_block_res=None, cross_attention_kwargs=None ): # TODO(Patrick, William) - attention mask is not used output_states = () @@ -843,6 +843,8 @@ def custom_forward(*inputs): output_states += (hidden_states,) if self.downsamplers is not None: + if down_block_res is not None: + hidden_states += down_block_res for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 8cd3dcf42307..b054a840e236 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -576,23 +576,28 @@ def forward( # 2. pre-process sample = self.conv_in(sample) + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + is_t2i = mid_block_additional_residual is None and down_block_additional_residuals is not None # 3. 
down down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # `is_t2i` (set above: down residuals given, no mid residual): pop one residual and hand it to the block as `down_block_res` + kwargs = {} if not is_t2i else {"down_block_res": down_block_additional_residuals.pop()} sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, + **kwargs, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) down_block_res_samples += res_samples - if down_block_additional_residuals is not None: + if is_controlnet: new_down_block_res_samples = () for down_block_res_sample, down_block_additional_residual in zip( @@ -613,7 +618,7 @@ def forward( cross_attention_kwargs=cross_attention_kwargs, ) - if mid_block_additional_residual is not None: + if is_controlnet: sample = sample + mid_block_additional_residual # 5. up