Skip to content

Commit 7b294e4

Browse files
committed
Revert "Fix slow tests (huggingface#689)"
This reverts commit b2cfc7a.
1 parent 14b9754 commit 7b294e4

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

src/diffusers/models/attention.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,8 +274,13 @@ def forward(self, hidden_states, context=None, mask=None):
274274
return self.to_out(hidden_states)
275275

276276
def _attention(self, query, key, value):
277-
# TODO: use baddbmm for better performance
278-
attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
277+
attention_scores = torch.baddbmm(
278+
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
279+
query,
280+
key.transpose(-1, -2),
281+
beta=0,
282+
alpha=self.scale,
283+
)
279284
attention_probs = attention_scores.softmax(dim=-1)
280285
# compute attention output
281286
hidden_states = torch.matmul(attention_probs, value)

0 commit comments

Comments
 (0)