
[Core] fix QKV fusion for attention #8829


Merged
sayakpaul merged 19 commits into main from fix-qkv-fusion on Jul 24, 2024

Conversation

sayakpaul
Member

What does this PR do?

This PR fixes QKV fusion. Since Attention modules are nested within our models, the QKV fusion processors need to be applied recursively (a minimal sketch of that idea follows the list below).

Additionally, it:

  • fixes the implementation of FusedJointAttnProcessor2_0 so that it properly makes use of the fused matrices.
  • adds a new FusedHunyuanAttnProcessor2_0 that respects Hunyuan's use of rotary embeddings.
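For context, a minimal sketch of the recursive application described above (the helper name fuse_qkv_recursively is illustrative, not the exact diffusers code; the real logic lives in the models' fuse_qkv_projections methods):

from diffusers.models.attention_processor import Attention, FusedJointAttnProcessor2_0

def fuse_qkv_recursively(model):
    # model.modules() yields every nested submodule, so fusion also reaches
    # Attention blocks buried inside transformer blocks.
    for module in model.modules():
        if isinstance(module, Attention):
            # Concatenate the separate to_q/to_k/to_v Linears into single, wider projections.
            module.fuse_projections(fuse=True)
    # Swap in a processor that actually consumes the fused matrices.
    model.set_attn_processor(FusedJointAttnProcessor2_0())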

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul sayakpaul marked this pull request as ready for review July 11, 2024 05:20
@sayakpaul sayakpaul requested review from DN6 and yiyixuxu and removed request for DN6 and yiyixuxu July 11, 2024 05:20
@sayakpaul sayakpaul marked this pull request as draft July 11, 2024 05:28
@sayakpaul sayakpaul requested review from yiyixuxu and DN6 and removed request for yiyixuxu July 11, 2024 05:31
@sayakpaul sayakpaul marked this pull request as ready for review July 11, 2024 05:31
@yiyixuxu
Collaborator

So this did not work at all before; can we make sure to test in the future, even for optimization PRs?
Can we see how much of a speedup we get with this fix?

@sayakpaul
Member Author

Yeah, Yiyi, I am sorry for the oversight on my part. Fusing the attention projection matrices becomes relevant when you are quantizing (especially at lower precisions like int8). This is because, for the individual projection matrices, the dimensions are too thin for quantization to show its benefits.

So, we fuse these projection matrices. We talk a bit about this here: https://pytorch.org/blog/accelerating-generative-ai-3/#dynamic-int8-quantization. Previously, quantization still worked when the dimensionality constraints were satisfied, but we did not get any benefit from thickening the dimensionality of the projection layers in attention.
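To make the "thickening" concrete, here is an illustrative toy sketch (hypothetical dimensions, not the diffusers implementation) of how three thin projections become one wider Linear:

import torch
import torch.nn as nn

hidden = 1536  # hypothetical hidden size

# Three separate, "thin" projections.
to_q = nn.Linear(hidden, hidden)
to_k = nn.Linear(hidden, hidden)
to_v = nn.Linear(hidden, hidden)

# One fused projection: the output dimension triples, so int8 kernels get a
# much larger matmul to work with.
to_qkv = nn.Linear(hidden, 3 * hidden)
with torch.no_grad():
    to_qkv.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight], dim=0))
    to_qkv.bias.copy_(torch.cat([to_q.bias, to_k.bias, to_v.bias], dim=0))

# A single matmul replaces three, and the result is split back afterwards.
x = torch.randn(2, 77, hidden)
q, k, v = to_qkv(x).chunk(3, dim=-1)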

With the changes in this PR, I am able to obtain the following numbers with quantization.

Without fusion + 8bit quant: Execution time: 6.470 sec
With fusion + 8bit quant: Execution time: 5.944 sec

So I would argue this is still a significant gain for such a small change.

Code
from diffusers import DiffusionPipeline
import argparse
import torch
import time
import bitsandbytes as bnb
import json

SHORT_NAME_MAPPER = {
    "stabilityai/stable-diffusion-3-medium-diffusers": "sd3",
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS": "pixart"
}


def load_pipeline(args):
    pipeline = DiffusionPipeline.from_pretrained(args.ckpt_id, torch_dtype=torch.float16).to("cuda")

    def replace_regular_linears(module, mode="8bit"):
        for name, child in module.named_children():
            if isinstance(child, torch.nn.Linear):
                in_features = child.in_features
                out_features = child.out_features
                device = child.weight.data.device

                # Create and configure the Linear layer
                has_bias = child.bias is not None
                if mode == "8bit":
                    new_layer = bnb.nn.Linear8bitLt(in_features, out_features, bias=has_bias, has_fp16_weights=False)
                else:
                    # TODO: Make that configurable
                    # fp16 for compute dtype leads to faster inference
                    # and one should almost always use nf4 as a rule of thumb
                    bnb_4bit_compute_dtype = torch.float16
                    quant_type = "nf4"

                    new_layer = bnb.nn.Linear4bit(
                        in_features,
                        out_features,
                        bias=has_bias,
                        compute_dtype=bnb_4bit_compute_dtype,
                        quant_type=quant_type,
                    )

                new_layer.load_state_dict(child.state_dict())
                new_layer = new_layer.to(device)

                # Set the attribute
                setattr(module, name, new_layer)
            else:
                # Recursively apply to child modules
                replace_regular_linears(child, mode=mode)

    if args.fuse:
        pipeline.transformer.fuse_qkv_projections()

    if args.mode is not None:
        replace_regular_linears(pipeline.transformer, args.mode)

    pipeline.set_progress_bar_config(disable=True)
    return pipeline


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_id", default="stabilityai/stable-diffusion-3-medium-diffusers", type=str, choices=list(SHORT_NAME_MAPPER.keys()))
    parser.add_argument("--mode", default=None, type=str, choices=["8bit", "4bit"])
    parser.add_argument("--fuse", default=0, type=int, choices=[0, 1])
    parser.add_argument("--prompt", default="a golden vase with different flowers", type=str)
    args = parser.parse_args()

    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    pipeline = load_pipeline(args)

    for _ in range(5):
        _ = pipeline(args.prompt, generator=torch.manual_seed(2024))

    start = time.time()
    output = pipeline(args.prompt, generator=torch.manual_seed(2024))
    end = time.time()
    mem_bytes = torch.cuda.max_memory_allocated()

    image = output.images[0]
    filename_prefix = f"{SHORT_NAME_MAPPER[args.ckpt_id]}" + "_".join(args.prompt.split(" ")) 
    if args.mode is not None:
        filename_prefix += f"_{args.mode}"
    if args.fuse:
        filename_prefix += f"_fuse@{args.fuse}"
    image.save(f"{filename_prefix}.png")

    print(f"Memory: {mem_bytes/(10**6):.3f} MB")
    print(f"Execution time: {(end - start):.3f} sec")

    info = dict(memory=f"{mem_bytes/(10**6):.3f}", time=f"{(end - start):.3f}")
    with open(f"{filename_prefix}.json", "w") as f:
        json.dump(info, f)
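For reference, the two timings above come from invocations along these lines (the script name is just whatever you save the snippet as; benchmark_fusion.py is a placeholder):

python benchmark_fusion.py --mode 8bit --fuse 0
python benchmark_fusion.py --mode 8bit --fuse 1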

LMK if something is still unclear.

I have tried my best to modify the tests so that we can be more rigorous about these silent bugs. But let me know if you have further ideas.
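Roughly, the kind of check the updated tests perform looks like this (function and attribute names here are illustrative, not the exact test code):

import torch

def assert_qkv_fusion_works(model, sample_inputs):
    # The output before fusion is the reference.
    with torch.no_grad():
        output_before = model(**sample_inputs).sample

    model.fuse_qkv_projections()

    # Every attention processor should now be a "Fused" variant; previously the
    # fused processors were silently not applied to nested Attention modules.
    assert all(
        proc.__class__.__name__.startswith("Fused")
        for proc in model.attn_processors.values()
    )

    # And fusion must not change the numerics.
    with torch.no_grad():
        output_after = model(**sample_inputs).sample
    assert torch.allclose(output_before, output_after, atol=1e-3)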

@sayakpaul sayakpaul requested a review from yiyixuxu July 18, 2024 11:35
Collaborator

@yiyixuxu yiyixuxu left a comment


Thanks!
I think it indeed makes sense to add this!

@sayakpaul sayakpaul requested a review from yiyixuxu July 19, 2024 12:04
Collaborator

@yiyixuxu yiyixuxu left a comment


Thanks! I left one comment about the test.
Looks very good to me otherwise.

@sayakpaul
Member Author

@yiyixuxu done!

@sayakpaul sayakpaul requested a review from yiyixuxu July 23, 2024 04:44
Collaborator

@yiyixuxu yiyixuxu left a comment


Thanks!

@sayakpaul sayakpaul merged commit 50d21f7 into main Jul 24, 2024
16 of 18 checks passed
@sayakpaul sayakpaul deleted the fix-qkv-fusion branch July 24, 2024 01:22
@sayakpaul
Member Author

Thank you for bearing with my oversight. Appreciate the patience.

sayakpaul added a commit that referenced this pull request Dec 23, 2024
* start debugging the problem,

* start

* fix

* fix

* fix imports.

* handle hunyuan

* remove residuals.

* add a check for making sure there's appropriate procs.

* add more rigor to the tests.

* fix test

* remove redundant check

* fix-copies

* move check_qkv_fusion_matches_attn_procs_length and check_qkv_fusion_processors_exist.