add only cross attention to simple attention blocks #3011

Conversation

williamberman (Contributor) commented on Apr 7, 2023

This PR allows the added-KV attention processor blocks to skip their self-attention and do only cross-attention.

We also pass the only_cross_attention argument through the "simple" UNet blocks.
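
For readers less familiar with these blocks, here is a minimal, self-contained sketch of the idea (illustrative PyTorch only, not the diffusers implementation; the class and argument names are made up for the example): the key/value projections over the encoder states are always computed, and the block's own self-attention keys/values are only concatenated in when the flag is off.

```python
import torch
import torch.nn as nn


class AddedKVAttentionSketch(nn.Module):
    """Illustrative sketch of an attention block with "added" encoder
    key/value projections that can optionally skip its self-attention."""

    def __init__(self, dim, cross_dim, only_cross_attention=False):
        super().__init__()
        self.only_cross_attention = only_cross_attention
        self.to_q = nn.Linear(dim, dim)
        # self-attention key/value projections (unused when only_cross_attention=True)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)
        # "added" key/value projections over the encoder (cross-attention) states
        self.add_k_proj = nn.Linear(cross_dim, dim)
        self.add_v_proj = nn.Linear(cross_dim, dim)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, hidden_states, encoder_hidden_states):
        query = self.to_q(hidden_states)
        # keys/values over the encoder states are always used
        key = self.add_k_proj(encoder_hidden_states)
        value = self.add_v_proj(encoder_hidden_states)
        if not self.only_cross_attention:
            # concatenate the block's own keys/values so the block attends to
            # itself and to the encoder states at the same time
            key = torch.cat([key, self.to_k(hidden_states)], dim=1)
            value = torch.cat([value, self.to_v(hidden_states)], dim=1)
        # single-head scaled dot-product attention, kept deliberately simple
        scores = query @ key.transpose(1, 2) / key.shape[-1] ** 0.5
        hidden_states = scores.softmax(dim=-1) @ value
        return self.to_out(hidden_states)
```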

williamberman force-pushed the only_cross_attention_simple_attention_blocks branch from c7d7ad9 to 9a56e8f on April 7, 2023 20:04
@@ -321,6 +322,7 @@ def __init__(
attn_num_head_channels=attention_head_dim[-1],
resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
+only_cross_attention=mid_block_only_cross_attention,
williamberman (Contributor, Author) commented on this diff:
@patrickvonplaten you had used only_cross_attention[-1] for this argument, but I think it makes sense to use a separate config option for the mid block. The other mid-block arguments that re-use the config of the last encoder block make sense because they are dimensionality-based and have to match; that constraint doesn't necessarily hold for the only_cross_attention flag.

williamberman (Contributor, Author) commented:

Actually, with this commit, 7329ead, we can do better: default mid_block_only_cross_attention to the value of only_cross_attention when only_cross_attention is given as a single boolean.
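
Roughly, that defaulting could look like the sketch below (behavioral sketch only; the helper name is made up and the exact code in the UNet `__init__` may differ):

```python
def resolve_mid_block_only_cross_attention(only_cross_attention, mid_block_only_cross_attention=None):
    # `only_cross_attention` may be a single bool or a per-down-block tuple/list.
    if mid_block_only_cross_attention is None and isinstance(only_cross_attention, bool):
        # a single boolean applies to every block, so the mid block inherits it
        return only_cross_attention
    if mid_block_only_cross_attention is None:
        # a per-block tuple says nothing about the mid block; fall back to False
        return False
    return mid_block_only_cross_attention


# an explicit mid-block setting always wins over the inherited default
assert resolve_mid_block_only_cross_attention(True) is True
assert resolve_mid_block_only_cross_attention((True, False, False)) is False
assert resolve_mid_block_only_cross_attention(True, mid_block_only_cross_attention=False) is False
```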

A Contributor replied:

Ok for me

HuggingFaceDocBuilderDev commented on Apr 7, 2023

The documentation is not available anymore as the PR was closed or merged.

williamberman force-pushed the only_cross_attention_simple_attention_blocks branch from 9a56e8f to ad99a33 on April 8, 2023 23:53
patrickvonplaten (Contributor) left a review comment:

Nice - some tests would be good as well :-)

williamberman force-pushed the only_cross_attention_simple_attention_blocks branch from ad99a33 to ea330a3 on April 10, 2023 21:43
williamberman (Contributor, Author) commented:
I rebased on top of #3046 because it was needed for the tests.

williamberman force-pushed the only_cross_attention_simple_attention_blocks branch from ea330a3 to 6a43702 on April 10, 2023 22:55
williamberman (Contributor, Author) commented:
@patrickvonplaten I'm not exactly sure what's best to test here. I added a sanity-check test in 320d288 that the outputs with and without only_cross_attention are different.

Could you take a look and let me know if that's sufficient before I merge?


only_cross_attn_out = attn(**forward_args)

self.assertTrue((only_cross_attn_out != self_and_cross_attn_out).all())
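
For illustration only, a sanity check along those lines could look like the sketch below. It reuses the hypothetical AddedKVAttentionSketch class from the earlier sketch and a looser torch.allclose check instead of the element-wise comparison quoted above; it is not the actual test added in 320d288.

```python
import torch


def sanity_check_only_cross_attention():
    torch.manual_seed(0)
    attn = AddedKVAttentionSketch(dim=32, cross_dim=64, only_cross_attention=False)

    forward_args = dict(
        hidden_states=torch.randn(2, 8, 32),
        encoder_hidden_states=torch.randn(2, 4, 64),
    )

    with torch.no_grad():
        self_and_cross_attn_out = attn(**forward_args)
        # same weights, same inputs; only the flag changes
        attn.only_cross_attention = True
        only_cross_attn_out = attn(**forward_args)

    # if the flag had no effect, the two outputs would match exactly
    assert not torch.allclose(only_cross_attn_out, self_and_cross_attn_out)
```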
A Member replied:
Sufficient for merging IMO.

sayakpaul (Member) left a review comment:

The test is sufficient for merging IMO.

allow mid_block_only_cross_attention to default to `only_cross_attention` when `only_cross_attention` is given as a single boolean
williamberman force-pushed the only_cross_attention_simple_attention_blocks branch from 7329ead to b9ffce2 on April 11, 2023 17:34
williamberman merged commit c6180a3 into huggingface:main on Apr 11, 2023
w4ffl35 pushed a commit to w4ffl35/diffusers that referenced this pull request Apr 14, 2023
* add only cross attention to simple attention blocks

* add test for only_cross_attention re: @patrickvonplaten

* mid_block_only_cross_attention better default

allow mid_block_only_cross_attention to default to
`only_cross_attention` when `only_cross_attention` is given
as a single boolean
dg845 pushed a commit to dg845/diffusers that referenced this pull request May 6, 2023 (same commit message as above)
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023 (same commit message as above)
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024 (same commit message as above)