From 072a5dd5ac51def0cb439c74508df3502c83edd1 Mon Sep 17 00:00:00 2001
From: William Berman
Date: Mon, 10 Apr 2023 14:36:25 -0700
Subject: [PATCH 1/2] `AttentionProcessor.group_norm` num_channels should be
 `query_dim`

The group_norm on the attention processor should really norm the number
of channels in the query _not_ the inner dim. This wasn't caught before
because the group_norm is only used by the added kv attention
processors and the added kv attention processors are only used by the
karlo models which are configured such that the inner dim is the same
as the query dim.
---
 src/diffusers/models/attention_processor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 3eeb132fe65e..e8999e23fa18 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -81,7 +81,7 @@ def __init__(
         self.added_kv_proj_dim = added_kv_proj_dim
 
         if norm_num_groups is not None:
-            self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
+            self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
         else:
             self.group_norm = None
 

From 33b4c025d68d254ed8825dee5b917a8f83885df5 Mon Sep 17 00:00:00 2001
From: William Berman
Date: Mon, 10 Apr 2023 15:50:37 -0700
Subject: [PATCH 2/2] add_{k,v}_proj should be projecting to inner_dim

---
 src/diffusers/models/attention_processor.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index e8999e23fa18..04ead2adcf6e 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -93,8 +93,8 @@ def __init__(
         self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
 
         if self.added_kv_proj_dim is not None:
-            self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
-            self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
+            self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)
+            self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim)
 
         self.to_out = nn.ModuleList([])
         self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
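
To illustrate why the two changes go together, below is a minimal, self-contained
sketch of the added-KV cross-attention layout the patches touch. This is not the
actual diffusers module (it omits multiple heads, dropout, the optional cross
attention norm, etc.); the class name, forward pass, and shapes are simplified
assumptions for illustration, but the dimension names follow the diff. The point
is that GroupNorm runs on the incoming hidden states, which carry query_dim
channels, while the extra key/value projections must land in inner_dim so they
can be concatenated with the regular keys and values.

    # Simplified sketch (assumed names/shapes), not the diffusers implementation.
    import torch
    import torch.nn as nn

    class AddedKVAttentionSketch(nn.Module):
        def __init__(self, query_dim, cross_attention_dim, inner_dim,
                     added_kv_proj_dim, norm_num_groups):
            super().__init__()
            # GroupNorm sees the raw hidden states (query_dim channels); the
            # projection to inner_dim only happens afterwards in to_q/to_k/to_v.
            self.group_norm = nn.GroupNorm(num_channels=query_dim,
                                           num_groups=norm_num_groups,
                                           eps=1e-5, affine=True)
            self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
            self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False)
            self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False)
            # The extra projections must also output inner_dim, otherwise they
            # could not be concatenated with the regular key/value tensors.
            self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)
            self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim)
            self.to_out = nn.Linear(inner_dim, query_dim)

        def forward(self, hidden_states, encoder_hidden_states, added_states):
            # hidden_states: (batch, query_dim, seq); GroupNorm normalizes channels.
            hidden_states = self.group_norm(hidden_states).transpose(1, 2)

            query = self.to_q(hidden_states)            # (batch, seq, inner_dim)
            key = self.to_k(encoder_hidden_states)      # (batch, enc_seq, inner_dim)
            value = self.to_v(encoder_hidden_states)

            # Extra tokens join the regular keys/values along the sequence axis.
            key = torch.cat([self.add_k_proj(added_states), key], dim=1)
            value = torch.cat([self.add_v_proj(added_states), value], dim=1)

            attn = torch.softmax(
                query @ key.transpose(-1, -2) / query.shape[-1] ** 0.5, dim=-1)
            return self.to_out(attn @ value)            # (batch, seq, query_dim)

    # Quick shape check with query_dim != inner_dim, the configuration the old
    # code broke on (Karlo happens to use query_dim == inner_dim, hiding the bug).
    attn = AddedKVAttentionSketch(query_dim=64, cross_attention_dim=96,
                                  inner_dim=128, added_kv_proj_dim=96,
                                  norm_num_groups=32)
    out = attn(torch.randn(2, 64, 16), torch.randn(2, 10, 96), torch.randn(2, 4, 96))
    print(out.shape)  # torch.Size([2, 16, 64])

With the pre-patch dimensions (num_channels=inner_dim, add_{k,v}_proj -> cross_attention_dim),
the GroupNorm and the concatenation above would both raise shape errors whenever
query_dim, inner_dim, and cross_attention_dim differ.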