Improve performance via batched-matmul and fused multiplies #7

@Birch-san

Many thanks for providing this reference implementation.

I tried integrating this into stable-diffusion / diffusers. A fix was required to make it work on Mac (PyTorch MPS backend):
Birch-san/diffusers@0437214

Knowing that computing attention via baddbmm() + bmm() can outperform einsum by 18%, I tried to rewrite the algorithm to use those operations.
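
For context, here's a minimal sketch of what I mean by computing attention via baddbmm() + bmm() (illustrative only, not the exact code from my PR; it assumes q/k/v are already in [batch * num_heads, tokens, channels_per_head] layout and that scale = channels_per_head ** -0.5):

import torch

def attention_baddbmm(q, k, v, scale):
    # q, k, v: [batch * num_heads, tokens, channels_per_head]
    # baddbmm fuses the scale multiply (alpha) into the q @ k^T matmul;
    # with beta=0 the uninitialized bias tensor is ignored, so only the
    # fused scaling remains.
    attn_bias = torch.empty(q.shape[0], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device)
    scores = torch.baddbmm(attn_bias, q, k.transpose(1, 2), beta=0, alpha=scale)
    probs = scores.softmax(dim=-1)
    return torch.bmm(probs, v)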

I compared the speed of my optimized version against the implementation in this repository.
This result is for "everything fits in one chunk" performance (i.e. chunk size = max token length). I was unable to compare chunked performance: although I got chunking working in my version, I wasn't able to get it working in the version in this repository (it returned some unexpectedly-shaped tensors).

Compared to the implementation in this repository:
my optimized version achieves a 2.78x speedup in the time it took to generate a 512x512 image with stable-diffusion v2.1-base (i.e. 4096 vision tokens, 5 attention heads, batch size of 2 due to CFG).

Here's my optimized implementation:
Birch-san/diffusers#1

Batched matmuls require 3D tensors, i.e. [batch * num_heads, tokens, channels_per_head].

Code that currently integrates against this repository's [batch, q_length, num_heads, qk_depth_per_head] format can migrate those tensors to the [batch * num_heads, q_length, channels_per_head] format favoured by my implementation like so:

query = query.transpose(1,2).flatten(end_dim=1)
key = key.transpose(1,2).flatten(end_dim=1)
value = value.transpose(1,2).flatten(end_dim=1)

The returned result remains in [batch * num_heads, q_length, qk_depth_per_head] format, and can be restored to [batch, q_length, num_heads, qk_depth_per_head] format like so:

result.unflatten(0, (-1, attn.heads)).transpose(1,2)
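
Here's the round-trip end-to-end, with shape checks (the sizes are illustrative and num_heads stands in for however your module exposes the head count, e.g. attn.heads):

import torch

# Illustrative sizes only
batch, q_length, num_heads, qk_depth_per_head = 2, 4096, 5, 64
query = torch.randn(batch, q_length, num_heads, qk_depth_per_head)

# [batch, q_length, num_heads, depth] -> [batch * num_heads, q_length, depth]
q_flat = query.transpose(1, 2).flatten(end_dim=1)
assert q_flat.shape == (batch * num_heads, q_length, qk_depth_per_head)

# ... run attention; the result stays in the flattened layout ...
result = q_flat

# [batch * num_heads, q_length, depth] -> [batch, q_length, num_heads, depth]
restored = result.unflatten(0, (-1, num_heads)).transpose(1, 2)
assert restored.shape == (batch, q_length, num_heads, qk_depth_per_head)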

I think a further speedup is possible too, by working out when chunking is not needed: we can compute whether unchunked attention would fit into memory, and prefer unchunked attention as a fast path where possible. This will be useful in a Unet, which runs attention at various resolutions.
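
As a rough sketch of that check (the helper name and the byte accounting here are illustrative assumptions, not the PR code):

def fits_unchunked(q, k, bytes_limit):
    # Estimate the size of the full [batch * num_heads, q_tokens, k_tokens] score
    # matrix that unchunked attention would materialize, and compare it to a budget.
    batch_x_heads, q_tokens, _ = q.shape
    k_tokens = k.shape[1]
    scores_bytes = batch_x_heads * q_tokens * k_tokens * q.element_size()
    return scores_bytes <= bytes_limit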

EDIT:
I have now added fast-paths for:

  • skipping kv-chunking when kv_chunk_size >= k_tokens
    • this turns the algorithm into "attention slicing"
  • skipping q-chunking when q_chunk_size >= q_tokens
  • skipping all chunking when kv_chunk_size >= k_tokens and q_chunk_size >= q_tokens
  • skipping all chunking when the q @ k matmul requires fewer bytes than a user-provided threshold (see the sketch below)
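
Putting the list above together, the decision logic looks roughly like this (a sketch that mirrors the bullets, not the code from my PR):

def choose_attention_path(q_tokens, k_tokens, q_chunk_size, kv_chunk_size,
                          scores_bytes, bytes_limit=None):
    # Mirrors the fast-path bullets above; returns which code path would run.
    if bytes_limit is not None and scores_bytes <= bytes_limit:
        return "unchunked (q @ k scores fit within the byte threshold)"
    if kv_chunk_size >= k_tokens and q_chunk_size >= q_tokens:
        return "unchunked (both dimensions fit in a single chunk)"
    if kv_chunk_size >= k_tokens:
        return "q-chunked only (attention slicing)"
    if q_chunk_size >= q_tokens:
        return "kv-chunked only"
    return "chunked over both q and kv"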
