
[Feat]: Splitting batched attention #1110

@dxqb

Description


Describe your use-case.

Torch SDPA is much faster when the flash attention backend can be used - but flash attention does not support arbitrary attention masks. All recent models use attention masking to handle variable caption lengths.

In practice, this means that with these models you only benefit from flash attention when training with batch size 1, see here: #1109
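For reference, this can be observed directly through PyTorch's SDPA backend selection. The following is a minimal sketch, assuming PyTorch >= 2.3 and a CUDA GPU with flash attention support; the tensor shapes and the boolean mask are made up for illustration:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Illustrative shapes: batch 2, 8 heads, 256 tokens, head dim 64, fp16 on CUDA.
q, k, v = (torch.randn(2, 8, 256, 64, device="cuda", dtype=torch.float16) for _ in range(3))
mask = torch.ones(2, 1, 256, 256, device="cuda", dtype=torch.bool)

with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
    F.scaled_dot_product_attention(q, k, v)  # fine: the flash kernel is eligible
    try:
        F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    except RuntimeError as err:
        # With an explicit attn_mask the flash backend is not eligible, so
        # restricting SDPA to it fails with "No available kernel".
        print("flash attention backend rejected the mask:", err)
```

Outside the `sdpa_kernel` restriction, SDPA does not error when a mask is passed - it silently falls back to the efficient or math backends, which is where the slowdown comes from.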

What would you like to see as a solution?

This could be avoided by splitting the attention call: instead of running attention once with batch size 2 and a mask, you could run attention twice with batch size 1 and no mask.

Iterating over the batch without a mask is probably still much faster than a single masked call, because the flash attention backend can then be used.
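As an illustration only, here is a minimal sketch of what such a split could look like when the mask only hides padded caption (key/value) tokens. The tensor layout `(batch, heads, seq, head_dim)`, the function name, and the `kv_lens` argument are assumptions for this sketch, not taken from OneTrainer or musubi tuner:

```python
import torch
import torch.nn.functional as F

def split_batch_attention(q, k, v, kv_lens):
    """Run SDPA once per sample with the key/value padding sliced off, instead
    of once over the padded batch with an attn_mask, so that the flash
    attention backend stays eligible."""
    outs = []
    for i, n in enumerate(kv_lens):
        outs.append(
            F.scaled_dot_product_attention(
                q[i : i + 1],         # full query sequence for sample i
                k[i : i + 1, :, :n],  # drop padded caption/key tokens
                v[i : i + 1, :, :n],
            )
        )
    return torch.cat(outs, dim=0)
```

Dropping the padded keys gives the same softmax result as masking them out, so the outputs match the masked call. The per-sample loop adds some Python and kernel-launch overhead, but that should be small compared to the cost of running attention without the flash kernel.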

This seems to be implemented by musubi tuner here: https://github.com/kohya-ss/musubi-tuner/blob/85d047b99a3b4272da73addb0e047933a9b82624/src/musubi_tuner/hunyuan_model/attention.py#L158

But they implemented it by copying model code from diffusers, which I'd like to avoid.
Possible other ways:

Have you considered alternatives? List them here.

No response
