Description
Describe your use-case.
Torch SDPA is much faster when it can dispatch to flash attention, but flash attention doesn't support attention masks. All recent models use attention masking for variable caption lengths.
This means that with these models you mostly only benefit from flash attention if you train with batch size 1, see here: #1109
What would you like to see as a solution?
This could be avoided by splitting attention: instead of running attention once with a batch size of 2 and a mask, you could run attention 2 times with batch size 1 and no mask.
Iterating through the batch items without a mask is probably still much faster than masking, because flash attention can then be used.
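A minimal sketch of the idea (names like `split_sdpa` and `lengths` are illustrative, not existing code; it assumes `q`/`k`/`v` of shape `(batch, heads, seq, head_dim)`, self-attention where the same padding applies to queries and keys, and that outputs at padded positions don't matter downstream):

```python
import torch
import torch.nn.functional as F

def split_sdpa(q, k, v, lengths):
    # q, k, v: (batch, heads, seq, head_dim); lengths[i] = number of valid
    # (non-padded) tokens in sample i.
    seq_len = q.shape[2]
    outputs = []
    for i, n in enumerate(lengths):
        # Trim the padding off each sample instead of masking it; without an
        # attn_mask, SDPA is free to dispatch to the flash-attention kernel.
        out = F.scaled_dot_product_attention(
            q[i : i + 1, :, :n], k[i : i + 1, :, :n], v[i : i + 1, :, :n]
        )
        # Zero-pad the output back to the full sequence length; the padded
        # positions are assumed to be ignored downstream anyway.
        outputs.append(F.pad(out, (0, 0, 0, seq_len - n)))
    return torch.cat(outputs, dim=0)
```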
This seems to be implemented by musubi tuner here: https://github.com/kohya-ss/musubi-tuner/blob/85d047b99a3b4272da73addb0e047933a9b82624/src/musubi_tuner/hunyuan_model/attention.py#L158
But they implemented it by copying model code from diffusers, which I'd like to avoid.
Possible other ways:
- patching torch SDPA, like in Avoid attention masks for Qwen and Chroma #1109 (a rough sketch of that route is below this list)
- implementing a diffusers attention backend?
- ...?
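For the SDPA-patching route, a very rough sketch of what such a wrapper could look like. The boolean padding mask of shape `(batch, 1, 1, seq)` and self-attention with the same padding on queries and keys are assumptions about how the models build their masks, and none of these names are existing code; it only takes effect for modules that call `torch.nn.functional.scaled_dot_product_attention` via the module attribute:

```python
import torch
import torch.nn.functional as F

_original_sdpa = F.scaled_dot_product_attention

def patched_sdpa(query, key, value, attn_mask=None, **kwargs):
    # Only reroute plain boolean padding masks; everything else goes to the
    # original implementation unchanged. The (batch, 1, 1, key_len) layout is
    # an assumption about how the masks are constructed.
    if attn_mask is None or attn_mask.dtype != torch.bool or attn_mask.dim() != 4:
        return _original_sdpa(query, key, value, attn_mask=attn_mask, **kwargs)
    # True entries mark valid key tokens; count them per sample.
    lengths = attn_mask[:, 0, 0].sum(dim=-1).tolist()
    seq_len = query.shape[2]
    outputs = []
    for i, n in enumerate(lengths):
        # Per-sample call without a mask, so flash attention can be picked.
        out = _original_sdpa(
            query[i : i + 1, :, :n],
            key[i : i + 1, :, :n],
            value[i : i + 1, :, :n],
            **kwargs,
        )
        outputs.append(F.pad(out, (0, 0, 0, seq_len - n)))
    return torch.cat(outputs, dim=0)

F.scaled_dot_product_attention = patched_sdpa
```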
Have you considered alternatives? List them here.
No response