Qualcomm AI Engine Direct - Add llama sha transforming pass #6211

Merged
1 change: 1 addition & 0 deletions examples/models/llama/TARGETS
@@ -82,6 +82,7 @@ runtime.python_library(
"export_llama_lib.py",
"model.py",
"source_transformation/apply_spin_quant_r1_r2.py",
"source_transformation/attention.py",
"source_transformation/lora.py",
"source_transformation/pre_quantization.py",
"source_transformation/prune_vocab.py",
3 changes: 3 additions & 0 deletions examples/models/llama/export_llama.py
@@ -7,11 +7,14 @@
# Example script for exporting Llama2 to flatbuffer

import logging
import sys

import torch

from .export_llama_lib import build_args_parser, export_llama

sys.setrecursionlimit(4096)
Contributor

Is it still required in the latest commit?

Collaborator Author

Yes, it is needed when use_qnn_sha is enabled. Otherwise it triggers a maximum recursion depth error in the prepare_pt2e function.

Contributor

Can you add a comment explaining the reason? Also, how feasible would it be to guard this to args.qnn only?

Collaborator Author

Sorry for the late reply; I was running an experiment related to this comment.
I believe we can move this line (and the import sys) under a Qualcomm-specific condition:

@@ -557,6 +557,8 @@ def get_quantizer_and_quant_params(args):
     quantizers = get_pt2e_quantizers(pt2e_quant_params, args.so_library)
     quant_dtype = None
     if args.qnn and args.pt2e_quantize:
+        import sys
+        sys.setrecursionlimit(4096)

That guards it to args.qnn only. If this looks better, I will raise a PR to move it and add a comment.
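For reference, a minimal standalone sketch of that guard; the helper name maybe_raise_recursion_limit is hypothetical and not part of this PR, the real change would sit inside get_quantizer_and_quant_params as shown above:

import sys

def maybe_raise_recursion_limit(args) -> None:
    # prepare_pt2e recurses deeply on the SHA-transformed graph, so only the
    # QNN + pt2e-quantize path needs a higher interpreter recursion limit.
    if getattr(args, "qnn", False) and getattr(args, "pt2e_quantize", None):
        sys.setrecursionlimit(max(sys.getrecursionlimit(), 4096))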



FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)
59 changes: 44 additions & 15 deletions examples/models/llama/export_llama_lib.py
@@ -50,6 +50,8 @@
fuse_layer_norms,
get_model_with_r1_r2,
)

from .source_transformation.attention import replace_attention_to_attention_sha
from .source_transformation.quantize import (
get_quant_embedding_transform,
get_quant_weight_transform,
@@ -175,6 +177,12 @@ def build_args_parser() -> argparse.ArgumentParser:
help="checkpoint directory. Use with a sharded checkpoint, not for the standard llama2 model. Note, checkpoint_dir takes precedence over checkpoint if both are set.",
)

parser.add_argument(
"--use_qnn_sha",
action="store_true",
help="Change multi head attention to multiple single head attention for qnn backend (Qualcomm)",
)

parser.add_argument(
"--calibration_tasks",
nargs="+",
@@ -670,15 +678,24 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901
get_custom_quant_ios_dtype,
)

atten = builder_exported_to_edge.model.layers[0].attention
if args.use_qnn_sha:
cache_shape = torch.Size(
(atten.max_batch_size, atten.max_seq_len, atten.head_dim)
)
else:
cache_shape = torch.Size(
(
atten.max_batch_size,
atten.max_seq_len,
atten.n_kv_heads,
atten.head_dim,
)
)
# pyre-ignore
tag_quant_io(
builder_exported_to_edge.edge_manager.exported_program().graph_module,
partial(
get_custom_quant_ios_dtype, # pyre-ignore
builder_exported_to_edge.model.layers[
0
].attention.kv_cache.past_k_caches.shape,
),
partial(get_custom_quant_ios_dtype, cache_shape), # pyre-ignore
)

logging.info("Lowering model using following partitioner(s): ")
@@ -947,15 +964,27 @@ def _get_source_transforms( # noqa
convert_linear_to_conv2d,
)

transforms.append(replace_kv_cache_with_simple_kv_cache)
transforms.append(replace_sdpa_with_flex_sdpa)
transforms.append(replace_causal_mask)
transforms.append(replace_rms_norm_with_native_rms_norm)
if args.optimized_rotation_path:
transforms.append(fuse_layer_norms)
transforms.append(get_model_with_r1_r2(args.optimized_rotation_path))
# pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
transforms.append(convert_linear_to_conv2d)
if args.use_qnn_sha:
if args.optimized_rotation_path:
transforms.append(fuse_layer_norms)
transforms.append(
get_model_with_r1_r2(args.optimized_rotation_path)
)
transforms.append(replace_attention_to_attention_sha)
transforms.append(replace_causal_mask)
transforms.append(replace_rms_norm_with_native_rms_norm)
transforms.append(convert_linear_to_conv2d)
else:
transforms.append(replace_kv_cache_with_simple_kv_cache)
transforms.append(replace_sdpa_with_flex_sdpa)
transforms.append(replace_causal_mask)
transforms.append(replace_rms_norm_with_native_rms_norm)
if args.optimized_rotation_path:
transforms.append(fuse_layer_norms)
transforms.append(
get_model_with_r1_r2(args.optimized_rotation_path)
)
transforms.append(convert_linear_to_conv2d)

elif args.mps:
# Currently mps doesn't support sdpa op, use the simpler decomposition
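To make the ordering above concrete, here is an illustrative sketch of how a source-transform list such as the QNN SHA branch composes. apply_transforms is a hypothetical helper (not the LLMEdgeManager API); each entry is a callable that takes and returns the eager nn.Module, as replace_attention_to_attention_sha in this PR does.

from typing import Callable, List

import torch

def apply_transforms(
    model: torch.nn.Module,
    transforms: List[Callable[[torch.nn.Module], torch.nn.Module]],
) -> torch.nn.Module:
    # Transforms run in list order, so on the QNN SHA path the attention rewrite
    # happens before linear layers are converted to conv2d.
    for transform in transforms:
        model = transform(model)
    return model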
1 change: 0 additions & 1 deletion examples/models/llama/llama_transformer.py
@@ -276,7 +276,6 @@ def __init__(self, args: ModelArgs, layer_id: int):
self.max_batch_size = args.max_batch_size
self.max_seq_len = args.max_seq_len
self.dim = args.dim
# self.dim = 4096, self.n_heads = 32, self.head_dim = 4096 / 32 = 125
self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
219 changes: 219 additions & 0 deletions examples/models/llama/source_transformation/attention.py
@@ -0,0 +1,219 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

# Source transform that rewrites Llama multi-head attention into per-head (SHA) attention for the Qualcomm backend

import math
from typing import List, Optional, Tuple

import torch
from executorch.examples.models.llama.llama_transformer import Attention
from torch import nn


def apply_rotary_emb_single(
x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
) -> torch.Tensor:
x_r, x_i = x[..., ::2], x[..., 1::2]

x_out_r = x_r * freqs_cos - x_i * freqs_sin
x_out_i = x_r * freqs_sin + x_i * freqs_cos

x_out = torch.cat([x_out_r, x_out_i], dim=-1)
return x_out


class KVCacheSHA(torch.nn.Module):
def __init__(
self,
max_batch_size: int,
max_seq_length: int,
n_heads: int,
head_dim: int,
dtype=torch.float32,
):
super().__init__()

# a buffer per head
cache_shape = (max_batch_size, max_seq_length, head_dim)
for i in range(n_heads):
self.register_buffer(
f"past_k_caches_{i}",
torch.zeros(cache_shape, dtype=dtype, device="cpu"),
persistent=False,
)
self.register_buffer(
f"past_v_caches_{i}",
torch.zeros(cache_shape, dtype=dtype, device="cpu"),
persistent=False,
)

def update(
self,
input_pos: torch.Tensor,
k_val: torch.Tensor,
v_val: torch.Tensor,
cache_idx: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
new_k = torch.ops.aten.index_put_(
getattr(self, f"past_k_caches_{cache_idx}"), [None, input_pos], k_val
)
new_v = torch.ops.aten.index_put_(
getattr(self, f"past_v_caches_{cache_idx}"), [None, input_pos], v_val
)
return new_k, new_v

def get_cache(self, head_idx):
return getattr(self, f"past_k_caches_{head_idx}"), getattr(
self, f"past_v_caches_{head_idx}"
)


class SDPASHA(torch.nn.Module):

def __init__(
self,
max_batch_size: int,
max_seq_length: int,
n_heads: int,
n_rep: int,
head_dim: int,
dim: int,
):
super().__init__()
self.head_dim = head_dim
self.n_rep = n_rep
self.dim = dim
self.kv_cache = KVCacheSHA(
max_batch_size, max_seq_length, n_heads // n_rep, head_dim
)
self.scale_factor = math.sqrt(head_dim)

def forward(
self,
input_pos: torch.Tensor,
qs: List[torch.Tensor],
ks: List[torch.Tensor],
vs: List[torch.Tensor],
mask,
):

transpose_ks = []
for i in range(len(ks)):
new_k, _ = self.kv_cache.update(input_pos, ks[i], vs[i], i)
transpose_ks.append(new_k.transpose(-2, -1).contiguous())

output = []
for i, q in enumerate(qs):
cache_idx = i // self.n_rep
_, v = self.kv_cache.get_cache(cache_idx)

attn_mask = mask[input_pos]

attn_weight = q @ transpose_ks[cache_idx] / self.scale_factor
attn_weight += attn_mask
attn_weight = torch.softmax(attn_weight, dim=-1)
output.append(attn_weight @ v.contiguous())

return torch.cat(output, dim=-1)


class AttentionSHA(nn.Module):
def __init__(self, attention_mha: nn.Module):
super().__init__()
if not attention_mha.use_kv_cache:
            raise NotImplementedError("bert mode is not supported")

self.n_heads = attention_mha.n_heads
self.n_kv_heads = attention_mha.n_kv_heads
self.n_rep = self.n_heads // self.n_kv_heads
self.dim = attention_mha.dim
self.max_batch_size = attention_mha.max_batch_size
self.max_seq_len = attention_mha.max_seq_len
self.head_dim = attention_mha.dim // self.n_heads
self.SDPA = SDPASHA(
self.max_batch_size,
self.max_seq_len,
self.n_heads,
self.n_rep,
self.head_dim,
self.dim,
)
self.wq = nn.ModuleList(
[
nn.Linear(self.dim, self.head_dim, bias=False)
for _ in range(self.n_heads)
]
)
self.wk = nn.ModuleList(
[
nn.Linear(self.dim, self.head_dim, bias=False)
for _ in range(self.n_kv_heads)
]
)
self.wv = nn.ModuleList(
[
nn.Linear(self.dim, self.head_dim, bias=False)
for _ in range(self.n_kv_heads)
]
)

for i in range(self.n_heads):
self.wq[i].weight.data.copy_(
attention_mha.wq.weight[i * self.head_dim : (i + 1) * self.head_dim]
)
for i in range(self.n_kv_heads):
self.wk[i].weight.data.copy_(
attention_mha.wk.weight[i * self.head_dim : (i + 1) * self.head_dim]
)
self.wv[i].weight.data.copy_(
attention_mha.wv.weight[i * self.head_dim : (i + 1) * self.head_dim]
)
self.wo = attention_mha.wo

causal_mask = torch.tril(
torch.ones(
self.max_seq_len,
self.max_seq_len,
dtype=torch.bool,
device="cpu",
)
)
self.register_buffer("mask", causal_mask, persistent=False)

def forward(
self,
x: torch.Tensor,
freqs_cos: torch.Tensor,
freqs_sin: torch.Tensor,
input_pos: Optional[torch.Tensor] = None,
):
# QKV
q = [wq(x) for wq in self.wq]
k = [wk(x) for wk in self.wk]
v = [wv(x) for wv in self.wv]
for i in range(len(q)):
q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin)
for i in range(len(k)):
k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin)

output = self.SDPA(input_pos, q, k, v, self.mask)
return self.wo(output)


def replace_attention_to_attention_sha(module: torch.nn.Module):
for name, child in module.named_children():
if isinstance(child, Attention):
setattr(
module,
name,
AttentionSHA(child),
)
else:
replace_attention_to_attention_sha(child)
return module
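A rough usage sketch of the per-head pieces defined above, with made-up toy sizes (batch 1, 4 query heads sharing 2 KV heads); it only exercises SDPASHA and its embedded KVCacheSHA and is not taken from the PR's tests:

import torch

from executorch.examples.models.llama.source_transformation.attention import SDPASHA

# Toy sizes: 4 query heads, 2 KV heads (n_rep = 2), head_dim = 8, max sequence length 16.
bsz, seqlen, n_heads, n_kv_heads, head_dim, max_seq = 1, 1, 4, 2, 8, 16
sdpa = SDPASHA(bsz, max_seq, n_heads, n_heads // n_kv_heads, head_dim, n_heads * head_dim)

input_pos = torch.tensor([0])
mask = torch.tril(torch.ones(max_seq, max_seq, dtype=torch.bool))
qs = [torch.randn(bsz, seqlen, head_dim) for _ in range(n_heads)]
ks = [torch.randn(bsz, seqlen, head_dim) for _ in range(n_kv_heads)]
vs = [torch.randn(bsz, seqlen, head_dim) for _ in range(n_kv_heads)]

out = sdpa(input_pos, qs, ks, vs, mask)
print(out.shape)  # torch.Size([1, 1, 32]): per-head outputs concatenated, ready for the shared wo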