add only cross attention to simple attention blocks #3011
Conversation
Force-pushed from c7d7ad9 to 9a56e8f.
```diff
@@ -321,6 +322,7 @@ def __init__(
     attn_num_head_channels=attention_head_dim[-1],
     resnet_groups=norm_num_groups,
     resnet_time_scale_shift=resnet_time_scale_shift,
+    only_cross_attention=mid_block_only_cross_attention,
```
@patrickvonplaten you had used only_cross_attention[-1] for this argument, but I think it makes sense to use a separate config for the mid block. The other mid-block arguments that re-use the last encoder block's config make sense because they are dimensionality-based and have to match; that constraint doesn't necessarily hold for the only_cross_attention flag.
Actually, with commit 7329ead we can do better: default mid_block_only_cross_attention to the value of only_cross_attention when it's given as a single boolean.
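The defaulting described here amounts to roughly the following sketch (variable names are taken from the discussion; the surrounding constructor code is assumed, not the exact diffusers source):

```python
# Inside the UNet constructor (sketch): if no explicit mid-block flag was
# given and only_cross_attention is a single bool, the mid block inherits it.
if mid_block_only_cross_attention is None and isinstance(only_cross_attention, bool):
    mid_block_only_cross_attention = only_cross_attention

# Otherwise fall back to the previous behaviour of self + cross attention.
if mid_block_only_cross_attention is None:
    mid_block_only_cross_attention = False
```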
Ok for me
Force-pushed from 9a56e8f to ad99a33.
Nice - some tests would be good as well :-)
Force-pushed from ad99a33 to ea330a3.
I rebased on top of #3046 because it was needed for the tests.
Force-pushed from ea330a3 to 6a43702.
@patrickvonplaten I'm not exactly sure what's best to test here. I added a sanity-check test that the outputs with only_cross_attention vs. without are different, here: 320d288. Could you take a look and let me know if that's sufficient before I merge?
```python
only_cross_attn_out = attn(**forward_args)

self.assertTrue((only_cross_attn_out != self_and_cross_attn_out).all())
```
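A self-contained version of that sanity check might look like the sketch below; the dims, shapes, and constructor arguments are assumptions for illustration, not the PR's exact test code, and import paths may differ across diffusers versions:

```python
import torch
from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor

torch.manual_seed(0)
# The added-KV processor flattens (batch, channels, *spatial) inputs itself.
hidden_states = torch.randn(1, 32, 64)
encoder_hidden_states = torch.randn(1, 4, 8)

def make_attn(only_cross_attention: bool) -> Attention:
    torch.manual_seed(0)  # same weight init for both variants
    return Attention(
        query_dim=32,
        cross_attention_dim=32,  # self K/V project the hidden states
        added_kv_proj_dim=8,     # added K/V project the encoder states
        heads=2,
        dim_head=16,
        norm_num_groups=4,
        bias=True,
        only_cross_attention=only_cross_attention,
        processor=AttnAddedKVProcessor(),
    )

self_and_cross_attn_out = make_attn(False)(
    hidden_states, encoder_hidden_states=encoder_hidden_states
)
only_cross_attn_out = make_attn(True)(
    hidden_states, encoder_hidden_states=encoder_hidden_states
)

# Sanity check: the two attention modes must not produce identical outputs.
assert (only_cross_attn_out != self_and_cross_attn_out).all()
```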
Sufficient for merging IMO.
allow mid_block_only_cross_attention to default to `only_cross_attention` when `only_cross_attention` is given as a single boolean
Force-pushed from 7329ead to b9ffce2.
* add only cross attention to simple attention blocks
* add test for only_cross_attention re: @patrickvonplaten
* mid_block_only_cross_attention better default

allow mid_block_only_cross_attention to default to `only_cross_attention` when `only_cross_attention` is given as a single boolean
This PR allows the added-KV attention processor blocks to skip their self-attention and perform only cross-attention.
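The behavioural change inside the added-KV processor is roughly the following (a simplified sketch with a hypothetical helper function; the real AttnAddedKVProcessor differs in detail):

```python
import torch

def build_key_value(attn, hidden_states, enc_key_proj, enc_value_proj):
    # Sketch: enc_key_proj / enc_value_proj are the already head-split
    # projections of the encoder hidden states (shapes assumed).
    if attn.only_cross_attention:
        # Skip the self-attention projections entirely and attend only to
        # the encoder (cross) keys/values.
        return enc_key_proj, enc_value_proj
    # Default: concatenate the self K/V with the added cross K/V along the
    # sequence dimension so tokens attend to both.
    key = attn.head_to_batch_dim(attn.to_k(hidden_states))
    value = attn.head_to_batch_dim(attn.to_v(hidden_states))
    key = torch.cat([enc_key_proj, key], dim=1)
    value = torch.cat([enc_value_proj, value], dim=1)
    return key, value
```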
We also pass the only_cross_attention argument through the "simple" UNet blocks.
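As a usage sketch (the block names are the "simple" variants this PR touches; the dims are illustrative assumptions, not values from the PR):

```python
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=16,
    block_out_channels=(32, 64),
    down_block_types=("SimpleCrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "SimpleCrossAttnUpBlock2D"),
    mid_block_type="UNetMidBlock2DSimpleCrossAttn",
    cross_attention_dim=32,
    attention_head_dim=8,
    only_cross_attention=True,            # single bool covers all blocks
    mid_block_only_cross_attention=None,  # None -> inherits only_cross_attention
)
```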