From 796ae1cef53cee0ff3968b3de25cd9bfa06c399c Mon Sep 17 00:00:00 2001
From: Chen Lai
Date: Wed, 10 Apr 2024 21:32:41 -0700
Subject: [PATCH] Local change to export llama to QNN

Quantizer: lift constant scalar operands of binary aten ops into tensor
buffers so observers can see them, and skip annotation of non-float tensor
inputs. Export script: target SM8450 and keep unsqueeze/permute copy ops out
of the QNN partition. Llama transformer: replace scaled_dot_product_attention
with explicit matmul/softmax attention driven by an additive causal mask.
---
 backends/qualcomm/quantizer/quantizer.py    | 53 +++++++++++++++++++-
 backends/qualcomm/quantizer/utils.py        | 18 +++++--
 examples/models/llama2/export_llama_lib.py  |  9 ++--
 examples/models/llama2/llama_transformer.py | 55 +++++++++++++++------
 4 files changed, 111 insertions(+), 24 deletions(-)

diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py
index 674314d991c..85f3c741a2d 100644
--- a/backends/qualcomm/quantizer/quantizer.py
+++ b/backends/qualcomm/quantizer/quantizer.py
@@ -15,7 +15,8 @@
 from executorch.backends.qualcomm.passes.reduce_dynamic_range import ReduceDynamicRange
 from executorch.backends.qualcomm.passes.remove_clone import RemoveClone
 from executorch.backends.qualcomm.passes.replace_inf_buffer import ReplaceInfBuffer
-
+from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
 from torch import Tensor
 from torch._ops import OpOverload
 from torch.ao.quantization.observer import (
@@ -378,8 +379,58 @@ def transform_for_annotation(self, model: GraphModule) -> GraphModule:
         model = DecomposeScaledDotProductAttention()(model).graph_module
         model = DecomposeSilu()(model).graph_module
         model = ReplaceInfBuffer()(model).graph_module
+        # Lift python scalar operands of binary aten ops into tensor buffers
+        # so that they can be observed and quantized.
+        self._lift_constant_scalar_operands(model)
         return model
 
+    def _lift_constant_scalar_operands(self, gm: torch.fx.GraphModule) -> None:
+        for n in gm.graph.nodes:
+            if n.op != "call_function" or n.target not in (
+                torch.ops.aten.add.Tensor,
+                torch.ops.aten.sub.Tensor,
+                torch.ops.aten.mul.Tensor,
+                torch.ops.aten.mul.Scalar,
+                torch.ops.aten.rsub.Scalar,
+            ):
+                continue
+
+            # Find the single tensor (Node) operand and the python scalar
+            # operand; skip nodes that do not have exactly one of each.
+            const_arg = None
+            non_const_arg = None
+            for arg in n.args:
+                if isinstance(arg, torch.fx.Node):
+                    non_const_arg = arg
+                else:
+                    const_arg = arg
+
+            if non_const_arg is None or const_arg is None:
+                continue
+
+            # Register the scalar as a float32 buffer and feed it back into
+            # the op through a get_attr node inserted right before it.
+            tensor_constant = torch.tensor([const_arg], dtype=torch.float32)
+            tensor_constant_name = get_new_attr_name_with_prefix("_tensor_constant_")(
+                gm
+            )
+            gm.register_buffer(tensor_constant_name, tensor_constant)
+
+            fake_mode = n.meta["val"].fake_mode
+            with gm.graph.inserting_before(n):
+                get_attr_node = gm.graph.get_attr(tensor_constant_name)
+                get_attr_node.meta["val"] = fake_mode.from_tensor(tensor_constant)
+
+            if n.target == torch.ops.aten.rsub.Scalar:
+                # rsub computes (scalar - tensor), so the lifted scalar becomes
+                # the first operand of the equivalent sub.Tensor.
+                n.args = (get_attr_node, non_const_arg) + n.args[2:]
+                n.target = torch.ops.aten.sub.Tensor
+            else:
+                n.args = (non_const_arg, get_attr_node) + n.args[2:]
+
+        gm.recompile()
+
     def validate(self, model: GraphModule) -> None:
         pass
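Note on the quantizer change above: observers can only be attached to float tensors, so a binary aten op that takes a bare Python scalar (e.g. x * 0.5) would otherwise leave one operand unobservable. The pass lifts each such scalar into a registered float32 buffer and routes it back through a get_attr node. Below is a minimal, self-contained sketch of the same idea on a torch.fx-traced toy module; Toy and _lifted_scalar_0 are illustrative names, and symbolic_trace (with operator.mul targets) stands in for the aten-level export graph the real pass runs on:

    import operator

    import torch
    from torch.fx import symbolic_trace


    class Toy(torch.nn.Module):
        def forward(self, x):
            # The scalar 0.5 appears in the traced graph as a plain Python float.
            return x * 0.5


    gm = symbolic_trace(Toy())
    for n in gm.graph.nodes:
        if n.op == "call_function" and n.target is operator.mul:
            scalar = n.args[1]  # the constant operand
            gm.register_buffer(
                "_lifted_scalar_0", torch.tensor([scalar], dtype=torch.float32)
            )
            with gm.graph.inserting_before(n):
                attr = gm.graph.get_attr("_lifted_scalar_0")
            # The op now consumes a tensor, which an observer can watch.
            n.args = (n.args[0], attr)
    gm.recompile()
    print(gm.code)

After the rewrite, the generated forward reads the lifted buffer instead of embedding the raw float, which mirrors what _lift_constant_scalar_operands does for add/sub/mul/rsub in the export graph.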
diff --git a/backends/qualcomm/quantizer/utils.py b/backends/qualcomm/quantizer/utils.py
index 809b7298eba..abb33f18bc5 100644
--- a/backends/qualcomm/quantizer/utils.py
+++ b/backends/qualcomm/quantizer/utils.py
@@ -7,6 +7,7 @@
 from typing import Callable, Dict, List, Optional, Sequence
 
 import torch
+from torch._subclasses import FakeTensor
 
 from torch._ops import OpOverload
 
@@ -41,6 +42,13 @@ def decorator(annotator: Callable):
 
     return decorator
 
 
+def _is_input_non_float_tensor(node: Node):
+    """Check whether the input is not a float tensor, so that annotation can
+    be skipped for the node, since observers only work with float tensors.
+    """
+    if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor):
+        return True
+    return node.meta["val"].dtype != torch.float32
+
+
 def _is_annotated(nodes: List[Node]):
     """
@@ -123,12 +132,14 @@ def annotate_binary(node: Node, quantization_config: QuantizationConfig) -> None
     input_qspec_map = {}
     input_act0 = node.args[0]
-    if isinstance(input_act0, Node):
+    if isinstance(input_act0, Node) and not _is_input_non_float_tensor(input_act0):
         input_qspec_map[input_act0] = input_act_qspec
 
     input_act1 = node.args[1]
-    if isinstance(input_act1, Node):
+    if isinstance(input_act1, Node) and not _is_input_non_float_tensor(input_act1):
         input_qspec_map[input_act1] = input_act_qspec
 
     node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
         input_qspec_map=input_qspec_map,
@@ -147,7 +158,8 @@ def annotate_sub(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_binary(node, quantization_config)
 
 
-@register_annotator([torch.ops.aten.mul.Tensor, torch.ops.aten.mul.Scalar])
+@register_annotator([torch.ops.aten.mul.Tensor])
 def annotate_mul(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_binary(node, quantization_config)
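Note on the annotation change above: _is_input_non_float_tensor treats a missing or non-FakeTensor "val" entry, or any dtype other than float32, as "do not observe". A small standalone illustration of the dtype check it relies on, using torch.export so that node.meta["val"] is populated with FakeTensors (the Toy module, its shapes, and the int64 index input are made-up example values, not taken from this patch):

    import torch
    from torch.export import export


    class Toy(torch.nn.Module):
        def forward(self, x, idx):
            return x + idx.to(torch.float32)


    ep = export(Toy(), (torch.rand(4), torch.arange(4)))
    for n in ep.graph_module.graph.nodes:
        val = n.meta.get("val", None)
        is_float = isinstance(val, torch.Tensor) and val.dtype == torch.float32
        print(n.name, getattr(val, "dtype", None), "annotate" if is_float else "skip")

The int64 `idx` placeholder is reported as "skip": attaching an observer to it would fail, which is exactly the case annotate_binary now guards against.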
diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index de1e711a2c9..b9b9c104552 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -647,17 +647,17 @@ def _export_llama(modelname, args) -> str:  # noqa: C901
             # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`
             generate_qnn_executorch_compiler_spec(
                 # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
-                soc_model=QcomChipset.SM8650,  # default to SM8650
+                soc_model=QcomChipset.SM8450,  # use SM8450 instead of the SM8650 default
                 backend_options=backend_options,
                 debug=False,
                 saver=False,
             ),
             skip_node_id_set={},
-            skip_node_op_set={},
+            skip_node_op_set={"aten.unsqueeze_copy.default", "aten.permute_copy.default"},
         )
     )
     # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`
-    _transform(builder_exported_to_edge.export_program())
+    _transform(builder_exported_to_edge.edge_manager.exported_program())
 
     if args.generate_etrecord:
         if not builder_exported_to_edge.edge_manager:
diff --git a/examples/models/llama2/llama_transformer.py b/examples/models/llama2/llama_transformer.py
index 2a259af59cb..57fbdbd3fc1 100644
--- a/examples/models/llama2/llama_transformer.py
+++ b/examples/models/llama2/llama_transformer.py
@@ -14,7 +14,7 @@
 import torch.nn.functional as F
 
 from torch import nn
-
+import math
 
 class RMSNorm(torch.nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
@@ -216,15 +216,23 @@ def __init__(self, args: ModelArgs, layer_id: int):
         self.use_sdpa_with_kv_cache_op = args.use_sdpa_with_kv_cache_op
         self.layer_id = layer_id
 
-        causal_mask = torch.tril(
-            torch.ones(
-                self.max_seq_len,
-                self.max_seq_len,
-                dtype=torch.bool,
-                device="cpu",
-            )
-        )
-        self.register_buffer("mask", causal_mask, persistent=False)
+        # Additive causal mask: 0 on and below the diagonal, -inf strictly
+        # above, consumed by the explicit matmul/softmax attention in forward().
+        mask = torch.full(
+            (1, 1, args.max_seq_len, args.max_seq_len),
+            float("-inf"),
+            device="cpu",
+        )
+        mask = torch.triu(mask, diagonal=1)
+        self.register_buffer("mask", mask)
 
         if self.use_kv_cache:
             self.kv_cache = KVCache(
@@ -264,18 +272,33 @@ def forward(
             v = v.transpose(1, 2)
 
             k, v = self.kv_cache.update(input_pos, k, v)
-            mask = self.mask[None, None, input_pos]
+            mask = torch.squeeze(self.mask, [0, 1])
+            mask = mask[None, None, input_pos]
 
             k = k.repeat_interleave(self.n_rep, dim=1)
             v = v.repeat_interleave(self.n_rep, dim=1)
-            y = F.scaled_dot_product_attention(
-                q, k, v, attn_mask=mask, dropout_p=0.0
-            )
-
-            y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
-
-            y = self.wo(y)
-            return y
+            # Explicit matmul/softmax attention in place of
+            # F.scaled_dot_product_attention; the additive causal mask is
+            # applied to the scores before the softmax.
+            scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
+            scores = scores + mask
+            scores = F.softmax(scores.float(), dim=-1).type_as(q)
+            output = torch.matmul(scores, v)  # (bs, n_local_heads, seqlen, head_dim)
+
+            output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
+
+            output = self.wo(output)
+            return output
         else:
             from .custom_ops.sdpa_with_kv_cache import sdpa_with_kv_cache  # noqa
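Note on the attention change above: the explicit matmul/softmax path is numerically equivalent to the F.scaled_dot_product_attention call it replaces as long as the additive causal mask is applied to the scores before the softmax, as done in the patch. A quick sanity check with arbitrary illustrative shapes (not the llama2 configuration):

    import math

    import torch
    import torch.nn.functional as F

    bsz, n_heads, seqlen, head_dim = 2, 4, 8, 16
    q = torch.randn(bsz, n_heads, seqlen, head_dim)
    k = torch.randn(bsz, n_heads, seqlen, head_dim)
    v = torch.randn(bsz, n_heads, seqlen, head_dim)
    # 0 on and below the diagonal, -inf strictly above: additive causal mask.
    mask = torch.triu(torch.full((seqlen, seqlen), float("-inf")), diagonal=1)

    scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
    scores = scores + mask  # mask before softmax
    attn = F.softmax(scores.float(), dim=-1).type_as(q)
    manual = torch.matmul(attn, v)

    reference = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
    print(torch.allclose(manual, reference, atol=1e-5))  # expected: True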