From 2f85f9e5776557e1353fa2d8ea5642588f007d43 Mon Sep 17 00:00:00 2001
From: Chun-I Tsai
Date: Wed, 2 Oct 2024 06:11:33 +0530
Subject: [PATCH 1/3] Qualcomm AI Engine Direct - Add llama sha transforming pass

- Add SHA pass
---
 examples/models/llama/export_llama.py        |   3 +
 examples/models/llama/export_llama_lib.py    |  38 ++-
 examples/models/llama/llama_transformer.py   |   1 -
 .../llama/source_transformation/attention.py | 219 ++++++++++++++++++
 4 files changed, 251 insertions(+), 10 deletions(-)
 create mode 100644 examples/models/llama/source_transformation/attention.py

diff --git a/examples/models/llama/export_llama.py b/examples/models/llama/export_llama.py
index 3d0d1b7bcfb..1899ccf4df6 100644
--- a/examples/models/llama/export_llama.py
+++ b/examples/models/llama/export_llama.py
@@ -7,11 +7,14 @@
 # Example script for exporting Llama2 to flatbuffer
 
 import logging
+import sys
 
 import torch
 
 from .export_llama_lib import build_args_parser, export_llama
 
+sys.setrecursionlimit(4096)
+
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT)
 
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 23b3589c2a0..9c3f6d81b01 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -50,6 +50,8 @@
     fuse_layer_norms,
     get_model_with_r1_r2,
 )
+
+from .source_transformation.attention import replace_attention_to_attention_sha
 from .source_transformation.quantize import (
     get_quant_embedding_transform,
     get_quant_weight_transform,
@@ -175,6 +177,12 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="checkpoint directory. Use with a sharded checkpoint, not for the standard llama2 model. Note, checkpoint_dir takes precedence over checkpoint if both are set.",
     )
 
+    parser.add_argument(
+        "--use_qnn_sha",
+        action="store_true",
+        help="Change multi head attention to multiple single head attention for qnn backend (Qualcomm)",
+    )
+
     parser.add_argument(
         "--calibration_tasks",
         nargs="+",
@@ -947,15 +955,27 @@ def _get_source_transforms(  # noqa
             convert_linear_to_conv2d,
         )
 
-        transforms.append(replace_kv_cache_with_simple_kv_cache)
-        transforms.append(replace_sdpa_with_flex_sdpa)
-        transforms.append(replace_causal_mask)
-        transforms.append(replace_rms_norm_with_native_rms_norm)
-        if args.optimized_rotation_path:
-            transforms.append(fuse_layer_norms)
-            transforms.append(get_model_with_r1_r2(args.optimized_rotation_path))
-        # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
-        transforms.append(convert_linear_to_conv2d)
+        if args.use_qnn_sha:
+            if args.optimized_rotation_path:
+                transforms.append(fuse_layer_norms)
+                transforms.append(
+                    get_model_with_r1_r2(args.optimized_rotation_path)
+                )
+            transforms.append(replace_attention_to_attention_sha)
+            transforms.append(replace_causal_mask)
+            transforms.append(replace_rms_norm_with_native_rms_norm)
+            transforms.append(convert_linear_to_conv2d)
+        else:
+            transforms.append(replace_kv_cache_with_simple_kv_cache)
+            transforms.append(replace_sdpa_with_flex_sdpa)
+            transforms.append(replace_causal_mask)
+            transforms.append(replace_rms_norm_with_native_rms_norm)
+            if args.optimized_rotation_path:
+                transforms.append(fuse_layer_norms)
+                transforms.append(
+                    get_model_with_r1_r2(args.optimized_rotation_path)
+                )
+            transforms.append(convert_linear_to_conv2d)
 
     elif args.mps:
         # Currently mps doesn't support sdpa op, use the simpler decomposition
diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index 76e8730328b..20b8b1e30d4 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -276,7 +276,6 @@ def __init__(self, args: ModelArgs, layer_id: int):
         self.max_batch_size = args.max_batch_size
         self.max_seq_len = args.max_seq_len
         self.dim = args.dim
-        # self.dim = 4096, self.n_heads = 32, self.head_dim = 4096 / 32 = 125
         self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
         self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
         self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
diff --git a/examples/models/llama/source_transformation/attention.py b/examples/models/llama/source_transformation/attention.py
new file mode 100644
index 00000000000..de368bbdc59
--- /dev/null
+++ b/examples/models/llama/source_transformation/attention.py
@@ -0,0 +1,219 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+# Example script for exporting Llama2 to flatbuffer
+
+import math
+from typing import List, Optional, Tuple
+
+import torch
+from executorch.examples.models.llama2.llama_transformer import Attention
+from torch import nn
+
+
+def apply_rotary_emb_single(
+    x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
+) -> torch.Tensor:
+    x_r, x_i = x[..., ::2], x[..., 1::2]
+
+    x_out_r = x_r * freqs_cos - x_i * freqs_sin
+    x_out_i = x_r * freqs_sin + x_i * freqs_cos
+
+    x_out = torch.cat([x_out_r, x_out_i], dim=-1)
+    return x_out
+
+
+class KVCacheSha(torch.nn.Module):
+    def __init__(
+        self,
+        max_batch_size: int,
+        max_seq_length: int,
+        n_heads: int,
+        head_dim: int,
+        dtype=torch.float32,
+    ):
+        super().__init__()
+
+        # a buffer per head
+        cache_shape = (max_batch_size, max_seq_length, head_dim)
+        for i in range(n_heads):
+            self.register_buffer(
+                f"past_k_caches_{i}",
+                torch.zeros(cache_shape, dtype=dtype, device="cpu"),
+                persistent=False,
+            )
+            self.register_buffer(
+                f"past_v_caches_{i}",
+                torch.zeros(cache_shape, dtype=dtype, device="cpu"),
+                persistent=False,
+            )
+
+    def update(
+        self,
+        input_pos: torch.Tensor,
+        k_val: torch.Tensor,
+        v_val: torch.Tensor,
+        cache_idx: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        new_k = torch.ops.aten.index_put_(
+            getattr(self, f"past_k_caches_{cache_idx}"), [None, input_pos], k_val
+        )
+        new_v = torch.ops.aten.index_put_(
+            getattr(self, f"past_v_caches_{cache_idx}"), [None, input_pos], v_val
+        )
+        return new_k, new_v
+
+    def get_cache(self, head_idx):
+        return getattr(self, f"past_k_caches_{head_idx}"), getattr(
+            self, f"past_v_caches_{head_idx}"
+        )
+
+
+class SDPASha(torch.nn.Module):
+
+    def __init__(
+        self,
+        max_batch_size: int,
+        max_seq_length: int,
+        n_heads: int,
+        n_rep: int,
+        head_dim: int,
+        dim: int,
+    ):
+        super().__init__()
+        self.head_dim = head_dim
+        self.n_rep = n_rep
+        self.dim = dim
+        self.kv_cache = KVCacheSha(
+            max_batch_size, max_seq_length, n_heads // n_rep, head_dim
+        )
+        self.scale_factor = math.sqrt(head_dim)
+
+    def forward(
+        self,
+        input_pos: torch.Tensor,
+        qs: List[torch.Tensor],
+        ks: List[torch.Tensor],
+        vs: List[torch.Tensor],
+        mask,
+    ):
+
+        transpose_ks = []
+        for i in range(len(ks)):
+            new_k, _ = self.kv_cache.update(input_pos, ks[i], vs[i], i)
+            transpose_ks.append(new_k.transpose(-2, -1).contiguous())
+
+        output = []
+        for i, q in enumerate(qs):
+            cache_idx = i // self.n_rep
+            _, v = self.kv_cache.get_cache(cache_idx)
+
+            attn_mask = mask[input_pos]
+
+            attn_weight = q @ transpose_ks[cache_idx] / self.scale_factor
+            attn_weight += attn_mask
+            attn_weight = torch.softmax(attn_weight, dim=-1)
+            output.append(attn_weight @ v.contiguous())
+
+        return torch.cat(output, dim=-1)
+
+
+class AttentionSha(nn.Module):
+    def __init__(self, attention_mha: nn.Module):
+        super().__init__()
+        if not attention_mha.use_kv_cache:
+            raise NotImplementedError("bert mode is not support")
+
+        self.n_heads = attention_mha.n_heads
+        self.n_kv_heads = attention_mha.n_kv_heads
+        self.n_rep = self.n_heads // self.n_kv_heads
+        self.dim = attention_mha.dim
+        self.max_batch_size = attention_mha.max_batch_size
+        self.max_seq_len = attention_mha.max_seq_len
+        self.head_dim = attention_mha.dim // self.n_heads
+        self.SDPA = SDPASha(
+            self.max_batch_size,
+            self.max_seq_len,
+            self.n_heads,
+            self.n_rep,
+            self.head_dim,
+            self.dim,
+        )
+        self.wq = nn.ModuleList(
+            [
+                nn.Linear(self.dim, self.head_dim, bias=False)
+                for _ in range(self.n_heads)
+            ]
+        )
+        self.wk = nn.ModuleList(
+            [
+                nn.Linear(self.dim, self.head_dim, bias=False)
+                for _ in range(self.n_kv_heads)
+            ]
+        )
+        self.wv = nn.ModuleList(
+            [
+                nn.Linear(self.dim, self.head_dim, bias=False)
+                for _ in range(self.n_kv_heads)
+            ]
+        )
+
+        for i in range(self.n_heads):
+            self.wq[i].weight.data.copy_(
+                attention_mha.wq.weight[i * self.head_dim : (i + 1) * self.head_dim]
+            )
+        for i in range(self.n_kv_heads):
+            self.wk[i].weight.data.copy_(
+                attention_mha.wk.weight[i * self.head_dim : (i + 1) * self.head_dim]
+            )
+            self.wv[i].weight.data.copy_(
+                attention_mha.wv.weight[i * self.head_dim : (i + 1) * self.head_dim]
+            )
+        self.wo = attention_mha.wo
+
+        causal_mask = torch.tril(
+            torch.ones(
+                self.max_seq_len,
+                self.max_seq_len,
+                dtype=torch.bool,
+                device="cpu",
+            )
+        )
+        self.register_buffer("mask", causal_mask, persistent=False)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        freqs_cos: torch.Tensor,
+        freqs_sin: torch.Tensor,
+        input_pos: Optional[torch.Tensor] = None,
+    ):
+        # QKV
+        q = [wq(x) for wq in self.wq]
+        k = [wk(x) for wk in self.wk]
+        v = [wv(x) for wv in self.wv]
+        for i in range(len(q)):
+            q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin)
+        for i in range(len(k)):
+            k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin)
+
+        output = self.SDPA(input_pos, q, k, v, self.mask)
+        return self.wo(output)
+
+
+def replace_attention_to_attention_sha(module: torch.nn.Module):
+    for name, child in module.named_children():
+        if isinstance(child, Attention):
+            setattr(
+                module,
+                name,
+                AttentionSha(child),
+            )
+        else:
+            replace_attention_to_attention_sha(child)
+    return module

From 8bceee8847d6554921740f2287951f482e91940f Mon Sep 17 00:00:00 2001
From: Chun-I Tsai
Date: Fri, 1 Nov 2024 08:50:30 +0530
Subject: [PATCH 2/3] Rebase and change class name *Sha-> SHA

---
 examples/models/llama/export_llama_lib.py     | 21 ++++++++++++------
 .../llama/source_transformation/attention.py  | 14 ++++++-------
 2 files changed, 22 insertions(+), 13 deletions(-)

diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 9c3f6d81b01..b2d971d2699 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -678,15 +678,24 @@ def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
             get_custom_quant_ios_dtype,
         )
 
+        atten = builder_exported_to_edge.model.layers[0].attention
+        if args.use_qnn_sha:
+            cache_shape = torch.Size(
+                (atten.max_batch_size, atten.max_seq_len, atten.head_dim)
+            )
+        else:
+            cache_shape = torch.Size(
+                (
+                    atten.max_batch_size,
+                    atten.max_seq_len,
+                    atten.n_kv_heads,
+                    atten.head_dim,
+                )
+            )
         # pyre-ignore
         tag_quant_io(
             builder_exported_to_edge.edge_manager.exported_program().graph_module,
-            partial(
-                get_custom_quant_ios_dtype,  # pyre-ignore
-                builder_exported_to_edge.model.layers[
-                    0
-                ].attention.kv_cache.past_k_caches.shape,
-            ),
+            partial(get_custom_quant_ios_dtype, cache_shape),  # pyre-ignore
         )
 
         logging.info("Lowering model using following partitioner(s): ")
diff --git a/examples/models/llama/source_transformation/attention.py b/examples/models/llama/source_transformation/attention.py
index de368bbdc59..c5a028d3400 100644
--- a/examples/models/llama/source_transformation/attention.py
+++ b/examples/models/llama/source_transformation/attention.py
@@ -12,7 +12,7 @@
 from typing import List, Optional, Tuple
 
 import torch
-from executorch.examples.models.llama2.llama_transformer import Attention
+from executorch.examples.models.llama.llama_transformer import Attention
 from torch import nn
 
 
@@ -28,7 +28,7 @@ def apply_rotary_emb_single(
     return x_out
 
 
-class KVCacheSha(torch.nn.Module):
+class KVCacheSHA(torch.nn.Module):
     def __init__(
         self,
         max_batch_size: int,
@@ -74,7 +74,7 @@ def get_cache(self, head_idx):
         )
 
 
-class SDPASha(torch.nn.Module):
+class SDPASHA(torch.nn.Module):
 
     def __init__(
         self,
@@ -89,7 +89,7 @@ def __init__(
         self.head_dim = head_dim
         self.n_rep = n_rep
         self.dim = dim
-        self.kv_cache = KVCacheSha(
+        self.kv_cache = KVCacheSHA(
             max_batch_size, max_seq_length, n_heads // n_rep, head_dim
         )
         self.scale_factor = math.sqrt(head_dim)
@@ -123,7 +123,7 @@ def forward(
         return torch.cat(output, dim=-1)
 
 
-class AttentionSha(nn.Module):
+class AttentionSHA(nn.Module):
     def __init__(self, attention_mha: nn.Module):
         super().__init__()
         if not attention_mha.use_kv_cache:
@@ -136,7 +136,7 @@ def __init__(self, attention_mha: nn.Module):
         self.max_batch_size = attention_mha.max_batch_size
         self.max_seq_len = attention_mha.max_seq_len
         self.head_dim = attention_mha.dim // self.n_heads
-        self.SDPA = SDPASha(
+        self.SDPA = SDPASHA(
             self.max_batch_size,
             self.max_seq_len,
             self.n_heads,
@@ -212,7 +212,7 @@ def replace_attention_to_attention_sha(module: torch.nn.Module):
             setattr(
                 module,
                 name,
-                AttentionSha(child),
+                AttentionSHA(child),
             )
         else:
             replace_attention_to_attention_sha(child)

From 607bc6ceff890bdfeb3d7e7aa649cefb11130983 Mon Sep 17 00:00:00 2001
From: Joey Tsai
Date: Mon, 11 Nov 2024 12:04:51 +0800
Subject: [PATCH 3/3] Fix internal test

---
 examples/models/llama/TARGETS | 1 +
 1 file changed, 1 insertion(+)

diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS
index d328adffbf7..cf387bfab24 100644
--- a/examples/models/llama/TARGETS
+++ b/examples/models/llama/TARGETS
@@ -82,6 +82,7 @@ runtime.python_library(
         "export_llama_lib.py",
         "model.py",
         "source_transformation/apply_spin_quant_r1_r2.py",
+        "source_transformation/attention.py",
         "source_transformation/lora.py",
         "source_transformation/pre_quantization.py",
         "source_transformation/prune_vocab.py",
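The core idea of the MHA-to-SHA rewrite in patch 1 is that AttentionSHA.__init__ copies row-slices of the fused wq/wk/wv weights into per-head nn.Linear modules. The sketch below is not part of the patch series; it uses made-up toy sizes and illustrative names only, and simply checks that the per-head split reproduces the fused projection.

# Standalone sketch, not part of the patch: toy sizes, illustrative names only.
import torch
from torch import nn

dim, n_heads = 16, 4
head_dim = dim // n_heads

# Stand-in for the fused Attention.wq projection: weight shape (n_heads * head_dim, dim).
fused_wq = nn.Linear(dim, n_heads * head_dim, bias=False)

# Per-head projections initialized from row-slices of the fused weight,
# mirroring the copy_ loops in AttentionSHA.__init__.
sha_wq = nn.ModuleList(
    [nn.Linear(dim, head_dim, bias=False) for _ in range(n_heads)]
)
for i in range(n_heads):
    sha_wq[i].weight.data.copy_(fused_wq.weight[i * head_dim : (i + 1) * head_dim])

x = torch.randn(1, 8, dim)  # (batch, seq, dim)
fused_out = fused_wq(x).view(1, 8, n_heads, head_dim)        # fused output, split per head
sha_out = torch.stack([wq(x) for wq in sha_wq], dim=2)       # per-head outputs, stacked
assert torch.allclose(fused_out, sha_out, atol=1e-6)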
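KVCacheSHA.update writes new key/value rows into a per-head cache buffer with torch.ops.aten.index_put_. If that call is unfamiliar, the following sketch (again not part of the patch, toy sizes) shows that passing [None, input_pos] as the index list is equivalent to slice-assigning cache[:, input_pos].

# Standalone sketch, not part of the patch: toy sizes only.
import torch

max_batch_size, max_seq_length, head_dim = 1, 8, 4
cache = torch.zeros(max_batch_size, max_seq_length, head_dim)

input_pos = torch.tensor([0, 1, 2])               # sequence positions being filled
k_val = torch.randn(max_batch_size, 3, head_dim)  # new key rows for those positions

# Same call shape as KVCacheSHA.update: None keeps every batch entry,
# input_pos selects which sequence slots to overwrite.
torch.ops.aten.index_put_(cache, [None, input_pos], k_val)

expected = torch.zeros(max_batch_size, max_seq_length, head_dim)
expected[:, input_pos] = k_val                    # plain advanced-indexing equivalent
assert torch.equal(cache, expected)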