Skip to content

[Don't merge] T2I - Design proposition #2708

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/diffusers/models/unet_2d_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,7 +808,7 @@ def __init__(
self.gradient_checkpointing = False

def forward(
self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, down_block_res=None, cross_attention_kwargs=None
):
# TODO(Patrick, William) - attention mask is not used
output_states = ()
Expand Down Expand Up @@ -843,6 +843,8 @@ def custom_forward(*inputs):
output_states += (hidden_states,)

if self.downsamplers is not None:
if down_block_res is not None:
hidden_states += down_block_res
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)

Expand Down
9 changes: 7 additions & 2 deletions src/diffusers/models/unet_2d_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,23 +576,28 @@ def forward(
# 2. pre-process
sample = self.conv_in(sample)

is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
is_t2i = mid_block_additional_residual is None and down_block_additional_residuals is not None
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
# find out whether `is_t2i` depending on the shape of the residual connections
kwargs = {} if not is_t2i else {"down_block_res": down_block_additional_residuals.pop()}
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
**kwargs,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

down_block_res_samples += res_samples

if down_block_additional_residuals is not None:
if is_controlnet:
new_down_block_res_samples = ()

for down_block_res_sample, down_block_additional_residual in zip(
Expand All @@ -613,7 +618,7 @@ def forward(
cross_attention_kwargs=cross_attention_kwargs,
)

if mid_block_additional_residual is not None:
if is_controlnet:
sample = sample + mid_block_additional_residual

# 5. up
Expand Down