[Core] fix QKV fusion for attention #8829
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
so this did not work at all before - can we make sure to test in the future, even for optimization PRs?
Yeah, Yiyi, I am sorry for the oversight on my part. Fusing the attention projection matrices becomes relevant when you are quantizing (especially at lower precisions like int8). For the individual projection matrices, the dimensions are too thin for quantization to show its magic, so we fuse them into a single, wider projection. We talk a bit about this here: https://pytorch.org/blog/accelerating-generative-ai-3/#dynamic-int8-quantization. Previously, quantization still worked when the dimensionality constraints were satisfied, but we didn't get any of the benefit that comes from thickening the dimensionality of the attention projection layers. With the changes in this PR, I am able to obtain the following numbers with quantization.

Without fusion + 8bit quant: Execution time: 6.470 sec

So, I would argue this is still significant for such a small change.

Code:
from diffusers import DiffusionPipeline
import argparse
import torch
import time
import bitsandbytes as bnb
import json
SHORT_NAME_MAPPER = {
    "stabilityai/stable-diffusion-3-medium-diffusers": "sd3",
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS": "pixart",
}
def load_pipeline(args):
    pipeline = DiffusionPipeline.from_pretrained(args.ckpt_id, torch_dtype=torch.float16).to("cuda")

    def replace_regular_linears(module, mode="8bit"):
        for name, child in module.named_children():
            if isinstance(child, torch.nn.Linear):
                in_features = child.in_features
                out_features = child.out_features
                device = child.weight.data.device
                # Create and configure the Linear layer
                has_bias = True if child.bias is not None else False
                if mode == "8bit":
                    new_layer = bnb.nn.Linear8bitLt(in_features, out_features, bias=has_bias, has_fp16_weights=False)
                else:
                    # TODO: Make that configurable
                    # fp16 for compute dtype leads to faster inference
                    # and one should almost always use nf4 as a rule of thumb
                    bnb_4bit_compute_dtype = torch.float16
                    quant_type = "nf4"
                    new_layer = bnb.nn.Linear4bit(
                        in_features,
                        out_features,
                        bias=has_bias,
                        compute_dtype=bnb_4bit_compute_dtype,
                        quant_type=quant_type,
                    )
                new_layer.load_state_dict(child.state_dict())
                new_layer = new_layer.to(device)
                # Set the attribute
                setattr(module, name, new_layer)
            else:
                # Recursively apply to child modules
                replace_regular_linears(child, mode=mode)

    if args.fuse:
        pipeline.transformer.fuse_qkv_projections()
    if args.mode is not None:
        replace_regular_linears(pipeline.transformer, args.mode)

    pipeline.set_progress_bar_config(disable=True)
    return pipeline
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_id", default="stabilityai/stable-diffusion-3-medium-diffusers", type=str, choices=list(SHORT_NAME_MAPPER.keys()))
    parser.add_argument("--mode", default=None, type=str, choices=["8bit", "4bit"])
    parser.add_argument("--fuse", default=0, type=int, choices=[0, 1])
    parser.add_argument("--prompt", default="a golden vase with different flowers", type=str)
    args = parser.parse_args()

    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    pipeline = load_pipeline(args)

    for _ in range(5):
        _ = pipeline(args.prompt, generator=torch.manual_seed(2024))

    start = time.time()
    output = pipeline(args.prompt, generator=torch.manual_seed(2024))
    end = time.time()
    mem_bytes = torch.cuda.max_memory_allocated()

    image = output.images[0]
    filename_prefix = f"{SHORT_NAME_MAPPER[args.ckpt_id]}" + "_".join(args.prompt.split(" "))
    if args.mode is not None:
        filename_prefix += f"_{args.mode}"
    if args.fuse:
        filename_prefix += f"_fuse@{args.fuse}"
    image.save(f"{filename_prefix}.png")

    print(f"Memory: {mem_bytes/(10**6):.3f} MB")
    print(f"Execution time: {(end - start):.3f} sec")
    info = dict(memory=f"{mem_bytes/(10**6):.3f}", time=f"{(end - start):.3f}")
    with open(f"{filename_prefix}.json", "w") as f:
        json.dump(info, f)

LMK if something is still unclear. I have tried my best to modify the tests so that we can be more rigorous about these silent bugs. But let me know if you have further ideas.
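For readers less familiar with the trick being discussed: fusing QKV means concatenating the separate query, key, and value projection weights into one wider linear layer, so the quantized kernel runs over a larger matmul. A minimal, illustrative sketch of the idea (not the actual diffusers implementation; fuse_qkv is a made-up helper name):

```python
import torch
import torch.nn as nn

def fuse_qkv(q: nn.Linear, k: nn.Linear, v: nn.Linear) -> nn.Linear:
    # Stack the three thin projection matrices into one wide one so that the
    # single projection matmul (and its int8 kernel) sees a larger output dim.
    has_bias = q.bias is not None
    fused = nn.Linear(
        q.in_features,
        q.out_features + k.out_features + v.out_features,
        bias=has_bias,
    )
    with torch.no_grad():
        # nn.Linear weights are (out_features, in_features), so concatenate along dim 0.
        fused.weight.copy_(torch.cat([q.weight, k.weight, v.weight], dim=0))
        if has_bias:
            fused.bias.copy_(torch.cat([q.bias, k.bias, v.bias], dim=0))
    return fused

# At attention time, a fused processor runs one projection and splits the result:
# qkv = fused(hidden_states); query, key, value = qkv.chunk(3, dim=-1)
```

With the benchmark script above, the comparison then amounts to toggling --fuse 0/1 together with --mode 8bit or --mode 4bit.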
thanks!
I think it indeed makes sense to add this!
thanks! I left one comment about the test
looks very good to me otherwise
@yiyixuxu done!
thanks!
Thank you for bearing with my oversight. Appreciate the patience.
* start debugging the problem
* start
* fix
* fix
* fix imports
* handle hunyuan
* remove residuals
* add a check for making sure there's appropriate procs
* add more rigor to the tests
* fix test
* remove redundant check
* fix-copies
* move check_qkv_fusion_matches_attn_procs_length and check_qkv_fusion_processors_exist
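For reference, here is a rough guess at what the two check helpers mentioned in the commits above verify; this is a sketch only, assuming the model exposes an attn_processors mapping, and the actual bodies live in the diffusers test utilities:

```python
def check_qkv_fusion_processors_exist(model):
    # After fusion, every attention processor on the model should be one of
    # the "Fused..." processor classes.
    return all(
        proc.__class__.__name__.startswith("Fused")
        for proc in model.attn_processors.values()
    )

def check_qkv_fusion_matches_attn_procs_length(model, original_attn_processors):
    # Fusion should neither drop nor duplicate processors: the number of
    # attention processors must match what the model had before fusion.
    return len(model.attn_processors) == len(original_attn_processors)
```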
What does this PR do?
This PR fixes QKV fusion. Since Attention modules are nested in our modules, the QKV fusion processors should be applied recursively. Additionally it:
* updates FusedJointAttnProcessor2_0 to properly make use of the fused matrices.
* updates FusedHunyuanAttnProcessor2_0 to respect its use of rotary embeddings.
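To illustrate the "applied recursively" point: the fused processors have to be set on every Attention module in the tree, not just the top-level ones. A hedged sketch of the shape of the fix, where apply_fused_processors, attention_cls, fused_processor_cls, fuse_projections, and set_processor are illustrative names standing in for the real diffusers hooks:

```python
import torch.nn as nn

def apply_fused_processors(model: nn.Module, attention_cls, fused_processor_cls):
    # nn.Module.modules() yields the module itself and *all* nested submodules,
    # so attention blocks buried inside transformer blocks are reached too,
    # unlike an iteration over direct children only.
    for module in model.modules():
        if isinstance(module, attention_cls):
            module.fuse_projections()                    # assumed per-module fusion hook
            module.set_processor(fused_processor_cls())  # assumed processor setter
```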