From 37c471dbee3e463292c16a709f0fe25066c89860 Mon Sep 17 00:00:00 2001 From: shewu-quic Date: Sun, 20 Oct 2024 18:57:01 -0700 Subject: [PATCH 1/5] qio + sha + cpu quantized embedding +r1r2 cuda --- backends/qualcomm/_passes/i64_to_i32.py | 7 +- backends/qualcomm/partition/common_defs.py | 1 + .../qualcomm/quantizer/custom_annotation.py | 33 +++ backends/qualcomm/utils/utils.py | 10 + examples/models/llama/export_llama.py | 3 + examples/models/llama/export_llama_lib.py | 65 +++++- examples/models/llama/llama_transformer.py | 19 +- .../apply_spin_quant_r1_r2.py | 2 +- .../llama/source_transformation/attention.py | 219 ++++++++++++++++++ extension/llm/export/builder.py | 18 +- 10 files changed, 348 insertions(+), 29 deletions(-) create mode 100644 examples/models/llama/source_transformation/attention.py diff --git a/backends/qualcomm/_passes/i64_to_i32.py b/backends/qualcomm/_passes/i64_to_i32.py index 1d2171cc37a..d4ebc40c59e 100644 --- a/backends/qualcomm/_passes/i64_to_i32.py +++ b/backends/qualcomm/_passes/i64_to_i32.py @@ -61,7 +61,12 @@ def _cast_to_int32(self, graph_module: torch.fx.GraphModule): to_dst_node.meta["val"] = node_val.to(torch.int32) # Replace usage of the src dtype result with the dst dtype result. - n.replace_all_uses_with(to_dst_node) + if n.name != "tokens": + n.replace_all_uses_with(to_dst_node) + else: + for user in n.users.copy(): + if user.name != "quantized_decomposed_embedding_4bit_dtype": + user.replace_input_with(n, to_dst_node) to_dst_node.args = (n,) def call(self, graph_module: torch.fx.GraphModule): diff --git a/backends/qualcomm/partition/common_defs.py b/backends/qualcomm/partition/common_defs.py index d68441c2f79..1c24d00390d 100644 --- a/backends/qualcomm/partition/common_defs.py +++ b/backends/qualcomm/partition/common_defs.py @@ -14,6 +14,7 @@ exir_ops.edge.aten.full.default, exir_ops.edge.aten.slice_scatter.default, exir_ops.edge.aten.copy.default, + exir_ops.edge.quantized_decomposed.embedding_4bit.dtype, ] to_be_implemented_operator = [ diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py index 881d24bbb5e..db82172a9e2 100644 --- a/backends/qualcomm/quantizer/custom_annotation.py +++ b/backends/qualcomm/quantizer/custom_annotation.py @@ -12,6 +12,7 @@ QuantizationConfig, ) from executorch.backends.qualcomm.quantizer.utils import QUANT_ANNOTATION_KEY +from executorch.exir.dialects._ops import ops as exir_ops from torch.ao.quantization.quantizer import ( QuantizationAnnotation, SharedQuantizationSpec, @@ -144,3 +145,35 @@ def annotate_matmul(node: Node, quantization_config: QuantizationConfig): for node in gm.graph.nodes: if node.op == "call_function" and node.target == torch.ops.aten.matmul.default: annotate_matmul(node, quantization_config_16a8w) + + +def get_custom_quant_ios_dtype( + cache_shape: torch.Size, + node: torch.fx.Node, + kv_dtype=torch.uint8, + sharding_dtype=torch.uint16, +): + """ + This function is specific for llama inputs and outputs + """ + if node.op == "placeholder" and "attention_sdpa_kv_cache_past_" in node.name: + return kv_dtype + + # Tag index put node before copy node, because copy is a skipped node in qnn + if ( + exir_ops.edge.aten.index_put.default == node.target + and node.meta["val"].shape == cache_shape + ): + return kv_dtype + + # Tag sharding io + if exir_ops.edge.llama.fallback.default in [ + u.target for u in list(node.users.keys()) + ] + [node.target]: + return sharding_dtype + + # Tag index op as quantized tensors. 
It is caused by sharding + if exir_ops.edge.aten.index.Tensor in [ + u.target for u in list(node.users.keys()) + ] + [node.target]: + return sharding_dtype diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index 88a84f2f9a6..30e04750b58 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -71,6 +71,7 @@ QCOM_PASS_EXPAND_BROADCAST_SHAPE, QCOM_PASS_SKIP_ADVANCED_REQUANT, QCOM_QNN_COMPILE_SPEC, + QCOM_QUANTIZED_IO, ) from executorch.exir import ExirExportedProgram @@ -876,3 +877,12 @@ def get_soc_to_chipset_map(): "SM8475": QcomChipset.SM8475, "SM8450": QcomChipset.SM8450, } + + +def tag_quant_io(gm: torch.fx.GraphModule, get_quant_io_dtype_fn: Callable): + """ + Tag io nodes which get/output quantized tensor. No need to insert q/dq in qnn_preprocess + """ + for node in gm.graph.nodes: + if dtype := get_quant_io_dtype_fn(node): + node.meta[QCOM_QUANTIZED_IO] = dtype diff --git a/examples/models/llama/export_llama.py b/examples/models/llama/export_llama.py index 3d0d1b7bcfb..1899ccf4df6 100644 --- a/examples/models/llama/export_llama.py +++ b/examples/models/llama/export_llama.py @@ -7,11 +7,14 @@ # Example script for exporting Llama2 to flatbuffer import logging +import sys import torch from .export_llama_lib import build_args_parser, export_llama +sys.setrecursionlimit(4096) + FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 940bcaecbc7..5511f553875 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -50,6 +50,8 @@ fuse_layer_norms, get_model_with_r1_r2, ) + +from .source_transformation.attention import replace_attention_to_attention_sha from .source_transformation.quantize import ( get_quant_embedding_transform, get_quant_weight_transform, @@ -174,6 +176,12 @@ def build_args_parser() -> argparse.ArgumentParser: help="checkpoint directory. Use with a sharded checkpoint, not for the standard llama2 model. 
Note, checkpoint_dir takes precedence over checkpoint if both are set.", ) + parser.add_argument( + "--use_qnn_sha", + action="store_true", + help="Change multi head attention to multiple single head attention for qnn backend (Qualcomm)", + ) + parser.add_argument( "--calibration_tasks", nargs="+", @@ -642,7 +650,7 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901 ) ) # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils` - from executorch.backends.qualcomm.utils.utils import _transform + from executorch.backends.qualcomm.utils.utils import _transform, tag_quant_io # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`, Optional type has no attribute `exported_program` _transform(builder_exported_to_edge.edge_manager.exported_program()) @@ -654,7 +662,32 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901 builder_exported_to_edge.metadata["get_n_layers"], shares=args.num_sharding, ) + from functools import partial + from executorch.backends.qualcomm.quantizer.custom_annotation import ( + get_custom_quant_ios_dtype, + ) + atten = builder_exported_to_edge.model.layers[0].attention + if args.use_qnn_sha: + cache_shape = torch.Size( + (atten.max_batch_size, atten.max_seq_len, atten.head_dim) + ) + else: + cache_shape = torch.Size( + ( + atten.max_batch_size, + atten.max_seq_len, + atten.n_kv_heads, + atten.head_dim, + ) + ) + tag_quant_io( + builder_exported_to_edge.edge_manager.exported_program().graph_module, + partial( + get_custom_quant_ios_dtype, + cache_shape, + ), + ) logging.info("Lowering model using following partitioner(s): ") for partitioner in partitioners: logging.info(f"--> {partitioner.__class__.__name__}") @@ -919,15 +952,27 @@ def _get_source_transforms( # noqa convert_linear_to_conv2d, ) - transforms.append(replace_kv_cache_with_simple_kv_cache) - transforms.append(replace_sdpa_with_flex_sdpa) - transforms.append(replace_causal_mask) - transforms.append(replace_rms_norm_with_native_rms_norm) - if args.optimized_rotation_path: - transforms.append(fuse_layer_norms) - transforms.append(get_model_with_r1_r2(args.optimized_rotation_path)) - # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`. 
- transforms.append(convert_linear_to_conv2d) + if args.use_qnn_sha: + if args.optimized_rotation_path: + transforms.append(fuse_layer_norms) + transforms.append( + get_model_with_r1_r2(args.optimized_rotation_path) + ) + transforms.append(replace_attention_to_attention_sha) + transforms.append(replace_causal_mask) + transforms.append(replace_rms_norm_with_native_rms_norm) + transforms.append(convert_linear_to_conv2d) + else: + transforms.append(replace_kv_cache_with_simple_kv_cache) + transforms.append(replace_sdpa_with_flex_sdpa) + transforms.append(replace_causal_mask) + transforms.append(replace_rms_norm_with_native_rms_norm) + if args.optimized_rotation_path: + transforms.append(fuse_layer_norms) + transforms.append( + get_model_with_r1_r2(args.optimized_rotation_path) + ) + transforms.append(convert_linear_to_conv2d) elif args.mps: # Currently mps doesn't support sdpa op, use the simpler decomposition diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py index 3c4e3f13e6f..d5ed038ade6 100644 --- a/examples/models/llama/llama_transformer.py +++ b/examples/models/llama/llama_transformer.py @@ -263,21 +263,22 @@ class Attention(nn.Module): def __init__(self, args: ModelArgs, layer_id: int): super().__init__() self.use_kv_cache = args.use_kv_cache - self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads - assert args.n_heads % self.n_kv_heads == 0 + self.n_heads = args.n_heads + self.n_kv_heads = self.n_heads if args.n_kv_heads is None else args.n_kv_heads + assert self.n_heads % self.n_kv_heads == 0 model_parallel_size = 1 - self.n_local_heads = args.n_heads // model_parallel_size + self.n_local_heads = self.n_heads // model_parallel_size self.n_local_kv_heads = self.n_kv_heads // model_parallel_size self.n_rep = self.n_local_heads // self.n_local_kv_heads - self.head_dim = args.dim // args.n_heads + self.head_dim = args.dim // self.n_heads self.max_batch_size = args.max_batch_size self.max_seq_len = args.max_seq_len self.dim = args.dim - # args.dim = 4096, args.n_heads = 32, self.head_dim = 4096 / 32 = 125 - self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False) - self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False) - self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False) - self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False) + # args.dim = 4096, self.n_heads = 32, self.head_dim = 4096 / 32 = 125 + self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False) + self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False) self.layer_id = layer_id diff --git a/examples/models/llama/source_transformation/apply_spin_quant_r1_r2.py b/examples/models/llama/source_transformation/apply_spin_quant_r1_r2.py index e71007b1958..60bbad5598d 100644 --- a/examples/models/llama/source_transformation/apply_spin_quant_r1_r2.py +++ b/examples/models/llama/source_transformation/apply_spin_quant_r1_r2.py @@ -98,7 +98,7 @@ def get_model_with_r1_r2(optimized_rotation_path: str): def apply_spin_quant_r1_r2(model: torch.nn.Module, optimized_rotation_path: str): - optimized_rotation = torch.load(optimized_rotation_path, weights_only=True) + optimized_rotation = torch.load(optimized_rotation_path, weights_only=True, map_location=torch.device('cpu')) R1 = 
optimized_rotation["R1"].to(torch.float32) config = model.params num_heads = config.n_heads diff --git a/examples/models/llama/source_transformation/attention.py b/examples/models/llama/source_transformation/attention.py new file mode 100644 index 00000000000..59d989b7d92 --- /dev/null +++ b/examples/models/llama/source_transformation/attention.py @@ -0,0 +1,219 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +# Example script for exporting Llama2 to flatbuffer + +import math +from typing import List, Optional, Tuple + +import torch +from executorch.examples.models.llama.llama_transformer import Attention +from torch import nn + + +def apply_rotary_emb_single( + x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor +) -> torch.Tensor: + x_r, x_i = x[..., ::2], x[..., 1::2] + + x_out_r = x_r * freqs_cos - x_i * freqs_sin + x_out_i = x_r * freqs_sin + x_i * freqs_cos + + x_out = torch.cat([x_out_r, x_out_i], dim=-1) + return x_out + + +class KVCacheSha(torch.nn.Module): + def __init__( + self, + max_batch_size: int, + max_seq_length: int, + n_heads: int, + head_dim: int, + dtype=torch.float32, + ): + super().__init__() + + # a buffer per head + cache_shape = (max_batch_size, max_seq_length, head_dim) + for i in range(n_heads): + self.register_buffer( + f"past_k_caches_{i}", + torch.zeros(cache_shape, dtype=dtype, device="cpu"), + persistent=False, + ) + self.register_buffer( + f"past_v_caches_{i}", + torch.zeros(cache_shape, dtype=dtype, device="cpu"), + persistent=False, + ) + + def update( + self, + input_pos: torch.Tensor, + k_val: torch.Tensor, + v_val: torch.Tensor, + cache_idx: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + new_k = torch.ops.aten.index_put_( + getattr(self, f"past_k_caches_{cache_idx}"), [None, input_pos], k_val + ) + new_v = torch.ops.aten.index_put_( + getattr(self, f"past_v_caches_{cache_idx}"), [None, input_pos], v_val + ) + return new_k, new_v + + def get_cache(self, head_idx): + return getattr(self, f"past_k_caches_{head_idx}"), getattr( + self, f"past_v_caches_{head_idx}" + ) + + +class SDPASha(torch.nn.Module): + + def __init__( + self, + max_batch_size: int, + max_seq_length: int, + n_heads: int, + n_rep: int, + head_dim: int, + dim: int, + ): + super().__init__() + self.head_dim = head_dim + self.n_rep = n_rep + self.dim = dim + self.kv_cache = KVCacheSha( + max_batch_size, max_seq_length, n_heads // n_rep, head_dim + ) + self.scale_factor = math.sqrt(head_dim) + + def forward( + self, + input_pos: torch.Tensor, + qs: List[torch.Tensor], + ks: List[torch.Tensor], + vs: List[torch.Tensor], + mask, + ): + + transpose_ks = [] + for i in range(len(ks)): + new_k, _ = self.kv_cache.update(input_pos, ks[i], vs[i], i) + transpose_ks.append(new_k.transpose(-2, -1).contiguous()) + + output = [] + for i, q in enumerate(qs): + cache_idx = i // self.n_rep + _, v = self.kv_cache.get_cache(cache_idx) + + attn_mask = mask[input_pos] + + attn_weight = q @ transpose_ks[cache_idx] / self.scale_factor + attn_weight += attn_mask + attn_weight = torch.softmax(attn_weight, dim=-1) + output.append(attn_weight @ v.contiguous()) + + return torch.cat(output, dim=-1) + + +class AttentionSha(nn.Module): + def __init__(self, attention_mha: nn.Module): + super().__init__() + if not attention_mha.use_kv_cache: + raise NotImplementedError("bert mode is not support") + + self.n_heads = 
attention_mha.n_heads + self.n_kv_heads = attention_mha.n_kv_heads + self.n_rep = self.n_heads // self.n_kv_heads + self.dim = attention_mha.dim + self.max_batch_size = attention_mha.max_batch_size + self.max_seq_len = attention_mha.max_seq_len + self.head_dim = attention_mha.dim // self.n_heads + self.SDPA = SDPASha( + self.max_batch_size, + self.max_seq_len, + self.n_heads, + self.n_rep, + self.head_dim, + self.dim, + ) + self.wq = nn.ModuleList( + [ + nn.Linear(self.dim, self.head_dim, bias=False) + for _ in range(self.n_heads) + ] + ) + self.wk = nn.ModuleList( + [ + nn.Linear(self.dim, self.head_dim, bias=False) + for _ in range(self.n_kv_heads) + ] + ) + self.wv = nn.ModuleList( + [ + nn.Linear(self.dim, self.head_dim, bias=False) + for _ in range(self.n_kv_heads) + ] + ) + + for i in range(self.n_heads): + self.wq[i].weight.data.copy_( + attention_mha.wq.weight[i * self.head_dim : (i + 1) * self.head_dim] + ) + for i in range(self.n_kv_heads): + self.wk[i].weight.data.copy_( + attention_mha.wk.weight[i * self.head_dim : (i + 1) * self.head_dim] + ) + self.wv[i].weight.data.copy_( + attention_mha.wv.weight[i * self.head_dim : (i + 1) * self.head_dim] + ) + self.wo = attention_mha.wo + + causal_mask = torch.tril( + torch.ones( + self.max_seq_len, + self.max_seq_len, + dtype=torch.bool, + device="cpu", + ) + ) + self.register_buffer("mask", causal_mask, persistent=False) + + def forward( + self, + x: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + input_pos: Optional[torch.Tensor] = None, + ): + # QKV + q = [wq(x) for wq in self.wq] + k = [wk(x) for wk in self.wk] + v = [wv(x) for wv in self.wv] + for i in range(len(q)): + q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin) + for i in range(len(k)): + k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin) + + output = self.SDPA(input_pos, q, k, v, self.mask) + return self.wo(output) + + +def replace_attention_to_attention_sha(module: torch.nn.Module): + for name, child in module.named_children(): + if isinstance(child, Attention): + setattr( + module, + name, + AttentionSha(child), + ) + else: + replace_attention_to_attention_sha(child) + return module diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index bd12c374b51..ec258f81a02 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -219,9 +219,7 @@ def pt2e_calibrate( from executorch.examples.models.llama.eval_llama_lib import ( GraphModuleEvalWrapper, ) - from executorch.examples.models.llama.evaluate import ( # pyre-ignore[21] - evaluate_model, - ) + from lm_eval.evaluator import simple_evaluate # pyre-ignore[21] except ImportError: raise ImportError( "Please install the llm eval dependency via examples/models/llama/install_requirements.sh" @@ -250,6 +248,7 @@ def calibrate_template( ) else: token_list.append(torch.argmax(logits[:], dim=-1).item()) + print("Calibration Result: ",tokenizer.decode(token_list)) calibrate_template( module=prepared_module, @@ -266,11 +265,14 @@ def calibrate_template( generate_full_logits=self.generate_full_logits, enable_dynamic_shape=self.enable_dynamic_shape, ) - eval_results = evaluate_model( - eval_wrapper, - calibration_tasks, - calibration_limit, - ) + + # Evaluate the model + with torch.no_grad(): + eval_results = simple_evaluate( + model=eval_wrapper, + tasks=calibration_tasks, + limit=calibration_limit, + ) for task, res in eval_results["results"].items(): print(f"{task}: {res}") From 30b31ac97982571e271566421c18ad3a1f5a22f2 Mon Sep 17 00:00:00 2001 
From: shewu-quic Date: Tue, 22 Oct 2024 02:56:18 -0700 Subject: [PATCH 2/5] optimize scale_factor and feedforward and change int32 for torken to improve embedding op --- backends/qualcomm/_passes/i64_to_i32.py | 7 +-- backends/qualcomm/_passes/layout_transform.py | 1 + backends/qualcomm/utils/utils.py | 4 +- examples/models/llama/export_llama_lib.py | 3 + examples/models/llama/model.py | 2 +- .../llama/source_transformation/sdpa.py | 57 +++++++++++++++++-- extension/llm/export/builder.py | 2 +- 7 files changed, 62 insertions(+), 14 deletions(-) diff --git a/backends/qualcomm/_passes/i64_to_i32.py b/backends/qualcomm/_passes/i64_to_i32.py index d4ebc40c59e..1d2171cc37a 100644 --- a/backends/qualcomm/_passes/i64_to_i32.py +++ b/backends/qualcomm/_passes/i64_to_i32.py @@ -61,12 +61,7 @@ def _cast_to_int32(self, graph_module: torch.fx.GraphModule): to_dst_node.meta["val"] = node_val.to(torch.int32) # Replace usage of the src dtype result with the dst dtype result. - if n.name != "tokens": - n.replace_all_uses_with(to_dst_node) - else: - for user in n.users.copy(): - if user.name != "quantized_decomposed_embedding_4bit_dtype": - user.replace_input_with(n, to_dst_node) + n.replace_all_uses_with(to_dst_node) to_dst_node.args = (n,) def call(self, graph_module: torch.fx.GraphModule): diff --git a/backends/qualcomm/_passes/layout_transform.py b/backends/qualcomm/_passes/layout_transform.py index 829c11fda42..a26e06d9b27 100644 --- a/backends/qualcomm/_passes/layout_transform.py +++ b/backends/qualcomm/_passes/layout_transform.py @@ -62,6 +62,7 @@ class LayoutTransform(ExportPass): exir_ops.edge.aten.prelu.default, exir_ops.edge.aten.relu.default, exir_ops.edge.aten._softmax.default, # TODO: Need to find a new solution to do "axis_order" to transform axis. + exir_ops.edge.aten.sigmoid.default, exir_ops.edge.aten.sqrt.default, exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.sum.dim_IntList, diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index 30e04750b58..7b7fff3e197 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -166,8 +166,8 @@ def __init__(self, weight, bias=None): super().__init__() use_bias = bias is not None self.conv = torch.nn.Conv2d( - in_channels=weight.shape[0], - out_channels=weight.shape[1], + in_channels=weight.shape[1], + out_channels=weight.shape[0], kernel_size=1, padding=0, bias=use_bias, diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 5511f553875..469d44079c5 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -66,6 +66,7 @@ replace_causal_mask, replace_kv_cache_with_coreml_kv_cache, replace_kv_cache_with_simple_kv_cache, + replace_feedforward_to_conv2d, replace_sdpa_with_coreml_sdpa, replace_sdpa_with_custom_op, replace_sdpa_with_flex_sdpa, @@ -961,6 +962,7 @@ def _get_source_transforms( # noqa transforms.append(replace_attention_to_attention_sha) transforms.append(replace_causal_mask) transforms.append(replace_rms_norm_with_native_rms_norm) + transforms.append(replace_feedforward_to_conv2d) transforms.append(convert_linear_to_conv2d) else: transforms.append(replace_kv_cache_with_simple_kv_cache) @@ -972,6 +974,7 @@ def _get_source_transforms( # noqa transforms.append( get_model_with_r1_r2(args.optimized_rotation_path) ) + transforms.append(replace_feedforward_to_conv2d) transforms.append(convert_linear_to_conv2d) elif args.mps: diff --git a/examples/models/llama/model.py 
b/examples/models/llama/model.py index e6f39e0cad5..50db5ad98be 100644 --- a/examples/models/llama/model.py +++ b/examples/models/llama/model.py @@ -245,7 +245,7 @@ def get_example_inputs_kvcache_sdpa(self): else: return ( torch.tensor( - [[1]], dtype=torch.long + [[1]], dtype=torch.int32 ), # tokens, with kv cache our input token length is always just 1 token. torch.tensor( [0], dtype=torch.long diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py index f8362648f32..9c74143646a 100644 --- a/examples/models/llama/source_transformation/sdpa.py +++ b/examples/models/llama/source_transformation/sdpa.py @@ -12,8 +12,9 @@ from typing import Tuple, Union import torch +import torch.nn.functional as F -from executorch.examples.models.llama.llama_transformer import KVCache, SDPA +from executorch.examples.models.llama.llama_transformer import KVCache, SDPA, FeedForward from executorch.examples.models.llama.source_transformation.quantized_kv_cache import ( QuantizedKVCache, ) @@ -171,12 +172,14 @@ def __init__( self, kv_cache: KVCache, dim: int, + head_dim: int, n_rep: int, ): super().__init__() self.kv_cache = kv_cache self.dim = dim self.n_rep = n_rep + self.scale_factor = math.sqrt(head_dim) def forward( self, @@ -195,8 +198,7 @@ def forward( v = repeat_kv(v, self.n_rep) attn_mask = mask[input_pos] - scale_factor = 1 / math.sqrt(q.size(-1)) - attn_weight = q @ k.transpose(-2, -1) * scale_factor + attn_weight = q @ k.transpose(-2, -1) / self.scale_factor attn_weight += attn_mask attn_weight = torch.softmax(attn_weight, dim=-1) y = attn_weight @ v @@ -223,7 +225,7 @@ def replace_sdpa_with_flex_sdpa(module: torch.nn.Module): setattr( module, name, - SDPAFlex(child.kv_cache, child.dim, child.n_rep), + SDPAFlex(child.kv_cache, child.dim, child.head_dim, child.n_rep), ) else: replace_sdpa_with_flex_sdpa(child) @@ -428,3 +430,50 @@ def replace_causal_mask(module: torch.nn.Module): for _, child in module.named_children(): replace_causal_mask(child) return module + +class FeedForwardConv2D(torch.nn.Module): + def __init__(self, w1: torch.nn.Linear, w2: torch.nn.Linear, w3: torch.nn.Linear): + super().__init__() + self.w1_conv = torch.nn.Conv2d( + in_channels=w1.weight.shape[1], + out_channels=w1.weight.shape[0], + kernel_size=1, + padding=0, + bias=False, + ) + self.w2_conv = torch.nn.Conv2d( + in_channels=w2.weight.shape[1], + out_channels=w2.weight.shape[0], + kernel_size=1, + padding=0, + bias=False, + ) + self.w3_conv = torch.nn.Conv2d( + in_channels=w3.weight.shape[1], + out_channels=w3.weight.shape[0], + kernel_size=1, + padding=0, + bias=False, + ) + + self.w1_conv.weight = torch.nn.Parameter(w1.weight.reshape(*w1.weight.shape, 1, 1)) + self.w2_conv.weight = torch.nn.Parameter(w2.weight.reshape(*w2.weight.shape, 1, 1)) + self.w3_conv.weight = torch.nn.Parameter(w3.weight.reshape(*w3.weight.shape, 1, 1)) + + + def forward(self, x): + rank = x.dim() + x = x.unsqueeze(-1) if rank == 3 else x.reshape(1, *x.shape, 1) + x = torch.transpose(x, 1, 2) + res = self.w2_conv(F.silu(self.w1_conv(x)) * self.w3_conv(x)) + res = torch.transpose(res, 1, 2) + res = res.squeeze(-1) if rank == 3 else res.reshape(*res.shape[1:3]) + return res + +def replace_feedforward_to_conv2d(module: torch.nn.Module): + for name, child in module.named_children(): + if isinstance(child, FeedForward): + setattr(module, name, FeedForwardConv2D(child.w1, child.w2, child.w3)) + else: + replace_feedforward_to_conv2d(child) + return module diff --git 
a/extension/llm/export/builder.py b/extension/llm/export/builder.py index ec258f81a02..67e7553f813 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -237,7 +237,7 @@ def calibrate_template( with torch.no_grad(): while token_list[-1] != tokenizer.eos_id and pos < max_len: logits = module( - torch.full((1, 1), token_list[pos]), + torch.full((1, 1), token_list[pos], dtype=torch.int32), torch.tensor((pos,)), ) pos += 1 From 3f188ffd555f2890ac481001802aa6effce0caea Mon Sep 17 00:00:00 2001 From: shewu-quic Date: Fri, 25 Oct 2024 00:15:54 -0700 Subject: [PATCH 3/5] delegated copy op --- backends/qualcomm/builders/__init__.py | 2 + backends/qualcomm/builders/op_copy.py | 91 +++++++++++++++++++ backends/qualcomm/partition/common_defs.py | 1 - .../qualcomm/quantizer/custom_annotation.py | 2 +- 4 files changed, 94 insertions(+), 2 deletions(-) mode change 100644 => 100755 backends/qualcomm/builders/__init__.py create mode 100755 backends/qualcomm/builders/op_copy.py mode change 100644 => 100755 backends/qualcomm/partition/common_defs.py mode change 100644 => 100755 backends/qualcomm/quantizer/custom_annotation.py diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py old mode 100644 new mode 100755 index 74fd58a3ec3..eab5a56b385 --- a/backends/qualcomm/builders/__init__.py +++ b/backends/qualcomm/builders/__init__.py @@ -14,6 +14,7 @@ op_ceil, op_clamp, op_conv2d, + op_copy, op_depth_to_space, op_dequantize, op_div, @@ -70,6 +71,7 @@ op_ceil, op_clamp, op_conv2d, + op_copy, op_depth_to_space, op_dequantize, op_div, diff --git a/backends/qualcomm/builders/op_copy.py b/backends/qualcomm/builders/op_copy.py new file mode 100755 index 00000000000..9dfea25df41 --- /dev/null +++ b/backends/qualcomm/builders/op_copy.py @@ -0,0 +1,91 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+from typing import Dict + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper + +import torch +from executorch.backends.qualcomm.utils.constants import ( + QCOM_QUANT_ATTRS, + QCOM_SCALE, + QCOM_ZERO_POINT, +) +from executorch.exir.dialects._ops import ops as exir_ops + +from .node_visitor import NodeVisitor, register_node_visitor +from .qnn_constants import OpElementWiseAdd, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class Copy(NodeVisitor): + target = ["aten.copy.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + ) -> PyQnnWrapper.PyQnnOpWrapper: + input_node = node.args[0] + input_tensor = self.get_tensor(input_node, node) + copy_inp_tensor_wrapper = self.define_tensor( + input_node, + input_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + is_input_tensor=True, + ) + # 'graph', 'name', 'op', 'target', 'args', and 'kwargs' + zero_input_node = torch.fx.Node( + node.graph, + node.name + "_runtime_scalar", + "call_function", + exir_ops.edge.aten.scalar_tensor.default, + (), # args + {}, # kwargs + ) + zero_input_tensor = torch.tensor(0, dtype=input_tensor.dtype) + if quant_attrs := input_node.meta.get(QCOM_QUANT_ATTRS): + quant_attrs = quant_attrs.copy() + quant_attrs[QCOM_ZERO_POINT] = 0 + quant_attrs[QCOM_SCALE] = 1 + zero_input_node.meta[QCOM_QUANT_ATTRS] = quant_attrs + + + zero_tensor_wrapper = self.define_tensor( + zero_input_node, + zero_input_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + nodes_to_wrappers, + is_input_tensor=True, + ) + copy_input_tensors = [copy_inp_tensor_wrapper, zero_tensor_wrapper] + + if quant_attrs := input_node.meta.get(QCOM_QUANT_ATTRS): + quant_attrs = quant_attrs.copy() + # Because there is no output after convert_pt2e, the QCOM_QUANT_ATTRS of node is none + node.meta[QCOM_QUANT_ATTRS] = quant_attrs + output_tensor = self.get_tensor(node, node) + output_tensor_wrapper = self.define_tensor( + node, + output_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + is_input_tensor=False, + ) + copy_output_tensors = [output_tensor_wrapper] + + copy_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpElementWiseAdd.op_name, + ) + copy_op.AddInputTensors(copy_input_tensors) + copy_op.AddOutputTensors(copy_output_tensors) + + return copy_op diff --git a/backends/qualcomm/partition/common_defs.py b/backends/qualcomm/partition/common_defs.py old mode 100644 new mode 100755 index 1c24d00390d..8f490bf1ed3 --- a/backends/qualcomm/partition/common_defs.py +++ b/backends/qualcomm/partition/common_defs.py @@ -13,7 +13,6 @@ exir_ops.edge.aten.clone.default, exir_ops.edge.aten.full.default, exir_ops.edge.aten.slice_scatter.default, - exir_ops.edge.aten.copy.default, exir_ops.edge.quantized_decomposed.embedding_4bit.dtype, ] diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py old mode 100644 new mode 100755 index db82172a9e2..d8c89d968ff --- a/backends/qualcomm/quantizer/custom_annotation.py +++ b/backends/qualcomm/quantizer/custom_annotation.py @@ -161,7 +161,7 @@ def get_custom_quant_ios_dtype( # Tag index put node before copy node, because copy is a skipped node in qnn if ( - exir_ops.edge.aten.index_put.default == node.target + exir_ops.edge.aten.copy.default == node.target and node.meta["val"].shape == 
cache_shape ): return kv_dtype From a603e485c30751c001568780d8576afedc7ad5c3 Mon Sep 17 00:00:00 2001 From: shewu-quic Date: Fri, 25 Oct 2024 15:19:41 +0800 Subject: [PATCH 4/5] support int32 token in runner and profiling --- .../qualcomm/runtime/QnnExecuTorchBackend.cpp | 25 ++++++++++++++++++- extension/llm/runner/text_prefiller.cpp | 4 +-- extension/llm/runner/text_token_generator.h | 2 +- runtime/executor/method.cpp | 13 +++++++++- 4 files changed, 39 insertions(+), 5 deletions(-) diff --git a/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp b/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp index 5a55df6da3f..f22a4b90035 100644 --- a/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp +++ b/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp @@ -11,6 +11,9 @@ #include #include #include +#include + +// #include namespace executorch { namespace backends { namespace qnn { @@ -26,6 +29,7 @@ using executorch::runtime::MemoryAllocator; using executorch::runtime::Result; // ========== Public method implementations ========================= constexpr const char* QNN_COMPILE_SPEC = "qnn_compile_spec"; +// static int hi = 0; Result QnnExecuTorchBackend::init( BackendInitContext& context, FreeableBuffer* processed, @@ -36,6 +40,11 @@ Result QnnExecuTorchBackend::init( qnn_context_blob.buffer = const_cast(processed->data()); qnn_context_blob.nbytes = processed->size(); + // std::string path_ = "model_"+std::to_string(hi)+".bin"; + // std::ofstream fout(path_, std::ios::binary); + // fout.write(static_cast(processed->data()), static_cast(processed->size())); + // fout.flush(); + // hi++; // convert CompileSpec to qnn ExecuTorch option for (auto& compile_spec : compile_specs) { @@ -180,11 +189,12 @@ Result QnnExecuTorchBackend::init( } return qnn_manager; } - +// static int qq = 0; Error QnnExecuTorchBackend::execute( BackendExecutionContext& context, DelegateHandle* handle, EValue** args) const { + auto begin = std::chrono::high_resolution_clock::now(); QnnManager* qnn_manager = static_cast(handle); std::vector> input_tensors = @@ -202,6 +212,14 @@ Error QnnExecuTorchBackend::execute( // update data ptr only should be fine input_tensors[i]->FillDataBuffer( args[i]->toTensor().const_data_ptr(), false /* copy_data */); + // if(qq < input_tensors.size()){ + // std::string path_ = "qinput_"+std::to_string(qq)+".raw"; + // std::ofstream fout(path_, std::ios::binary); + // fout.write(static_cast(args[i]->toTensor().const_data_ptr()), input_tensors[i]->GetBytes()); + // fout.flush(); + // qq++; + // } + } input_tensor_structs.push_back(input_tensors[i]->CloneTensorStruct()); } @@ -232,7 +250,12 @@ Error QnnExecuTorchBackend::execute( qnn_manager->ProfileExecuteData(context.event_tracer()) == Error::Ok, Internal, "Fail to profile graph"); + auto end = std::chrono::high_resolution_clock::now(); + auto elapsed = std::chrono::duration_cast(end - + begin); + QNN_EXECUTORCH_LOG_INFO( + "QNN Graph Execute Time in QnnExecuTorchBackend: %ld us", elapsed.count()); return Error::Ok; } diff --git a/extension/llm/runner/text_prefiller.cpp b/extension/llm/runner/text_prefiller.cpp index 705583d638b..2e0368411be 100644 --- a/extension/llm/runner/text_prefiller.cpp +++ b/extension/llm/runner/text_prefiller.cpp @@ -41,7 +41,7 @@ ::executorch::runtime::Result TextPrefiller::prefill( auto tokens = from_blob( prompt_tokens.data(), {1, num_prompt_tokens}, - exec_aten::ScalarType::Long); + exec_aten::ScalarType::Int); auto start_pos_tensor = from_blob(&start_pos, {1}, exec_aten::ScalarType::Long); @@ -60,7 +60,7 @@ 
::executorch::runtime::Result TextPrefiller::prefill( cur_token = prompt_tokens[0]; // initialize tensor wrappers - auto tokens = from_blob(&cur_token, {1, 1}, exec_aten::ScalarType::Long); + auto tokens = from_blob(&cur_token, {1, 1}, exec_aten::ScalarType::Int); auto start_pos_tensor = from_blob(&start_pos, {1}, exec_aten::ScalarType::Long); diff --git a/extension/llm/runner/text_token_generator.h b/extension/llm/runner/text_token_generator.h index 62b924a57d8..ce453882d16 100644 --- a/extension/llm/runner/text_token_generator.h +++ b/extension/llm/runner/text_token_generator.h @@ -71,7 +71,7 @@ class ET_EXPERIMENTAL TextTokenGenerator { // initialize tensor wrappers auto tokens_managed = from_blob( - token_data.data(), token_shape, executorch::aten::ScalarType::Long); + token_data.data(), token_shape, executorch::aten::ScalarType::Int); auto start_pos_managed = from_blob(&pos, {1}, executorch::aten::ScalarType::Long); diff --git a/runtime/executor/method.cpp b/runtime/executor/method.cpp index a05d789a808..e3d63352d46 100644 --- a/runtime/executor/method.cpp +++ b/runtime/executor/method.cpp @@ -26,7 +26,8 @@ #include #include #include - +#include +#include namespace executorch { namespace runtime { @@ -1004,6 +1005,7 @@ ET_NODISCARD Error Method::get_inputs(EValue* input_evalues, size_t length) { } Error Method::execute_instruction() { + auto begin = std::chrono::high_resolution_clock::now(); auto& chain = chains_[step_state_.chain_idx]; auto instructions = chain.s_chain_->instructions(); @@ -1030,6 +1032,9 @@ Error Method::execute_instruction() { chain.kernels_[step_state_.instr_idx](context, args.data()); // We reset the temp_allocator after the switch statement err = context.failure_state(); + auto op_index = instruction->instr_args_as_KernelCall()->op_index(); + auto op = serialization_plan_->operators()->Get(op_index); + std::cout <<"run op"<name()->c_str()<(end - + begin); + std::cout << "instruction->instr_args_type()" << static_cast(instruction->instr_args_type()) << std::endl; + std::cout<< "delegates_[delegate_idx].Execute Time:" < Date: Wed, 30 Oct 2024 19:44:11 -0700 Subject: [PATCH 5/5] - Replace copy with reshape - Delegated mutable buffer in AOT - Manage mutable buffer at runtime --- backends/qualcomm/builders/node_visitor.py | 35 +++++++++++--- backends/qualcomm/builders/op_copy.py | 31 ++----------- backends/qualcomm/builders/utils.py | 37 +++++++++++++++ .../qualcomm/partition/qnn_partitioner.py | 18 +------- .../qualcomm/runtime/QnnExecuTorchBackend.cpp | 46 ++++++++----------- backends/qualcomm/runtime/QnnManager.cpp | 29 ++++++++++++ 6 files changed, 118 insertions(+), 78 deletions(-) diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py index 514bc6efd78..e2f845d92df 100644 --- a/backends/qualcomm/builders/node_visitor.py +++ b/backends/qualcomm/builders/node_visitor.py @@ -36,6 +36,8 @@ get_parameter, is_graph_input, is_graph_output, + is_mutable_buffer_input, + is_mutable_buffer_output, is_parameter, ) @@ -214,7 +216,7 @@ def get_tensor_type( node: torch.fx.Node, tensor_type: PyQnnWrapper.Qnn_TensorType_t, ) -> PyQnnWrapper.Qnn_TensorType_t: - is_input = is_graph_input(node, self.edge_program) + is_input = is_graph_input(node, self.edge_program) or is_mutable_buffer_input(node, self.edge_program) is_output = is_graph_output(node) # handle logic for input/output tensors if is_input or is_output: @@ -245,6 +247,29 @@ def get_data_type( return QNN_TENSOR_TYPE_MAP[tensor.dtype] + def get_tensor_name( + self, + node: 
torch.fx.Node, + wrapper_idx: int = 0, + ): + tensor_name = f"{node.name}_{wrapper_idx}" + # The `input_{id}` is utilized for sorting at runtime. Due to multiple passes in qnn_preprocess, + # the input order between QNN and the original graph’s forward function may differ. + # The `mutbuf_{id}` is utilized for mapping I/O of mutable buffer at runtime. + # The `output_` is identified as the graph’s output at runtime to prevent confusion with per_tensor_dump. + if is_mutable_buffer_input(node, self.edge_program): + fqn = self.edge_program.graph_signature.inputs_to_buffers[node.target] + position_index = list(self.edge_program.graph_signature.buffers_to_mutate.values()).index(fqn) + tensor_name = f"input_{str(self.external_ids[node])}_mutbuf_{str(position_index)}_{tensor_name}" + elif is_graph_input(node, self.edge_program): + tensor_name = f"input_{str(self.external_ids[node])}_{tensor_name}" + elif is_mutable_buffer_output(node, self.edge_program): + position_index = list(self.edge_program.graph_signature.buffers_to_mutate.keys()).index(node.name) + tensor_name = f"output_mutbuf_{position_index}_{tensor_name}" + elif is_graph_output(node): + tensor_name = f"output_{tensor_name}" + return tensor_name + def define_custom_tensor_wrapper( self, node_name: str, @@ -305,11 +330,7 @@ def define_tensor( if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None): return cached - tensor_name = f"{node.name}_{wrapper_idx}" - if is_graph_input(node, self.edge_program): - tensor_name = "input_" + str(self.external_ids[node]) + "_" + tensor_name - if is_graph_output(node): - tensor_name = "output_" + tensor_name + tensor_name = self.get_tensor_name(node, wrapper_idx) dims = [1] if len(tensor.size()) == 0 else tensor.size() tensor_type = self.get_tensor_type(node, tensor_type) quant_encoding, quant_configs = self.get_quant_encoding_conf( @@ -381,7 +402,7 @@ def generate_node_to_external_map( # The order in which we visit the placeholder node is same as the *args # order for the forward(*args) signature for this gm. 
Using the order of # the nodes as external_id to extract the right arg from *args at runtime - if is_graph_input(node, edge_program): + if is_graph_input(node, edge_program) or is_mutable_buffer_input(node, edge_program): node_to_external_map[node] = len(node_to_external_map) for node in edge_program.graph_module.graph.nodes: if is_graph_output(node): diff --git a/backends/qualcomm/builders/op_copy.py b/backends/qualcomm/builders/op_copy.py index 9dfea25df41..aadfc6df1f1 100755 --- a/backends/qualcomm/builders/op_copy.py +++ b/backends/qualcomm/builders/op_copy.py @@ -16,7 +16,7 @@ from executorch.exir.dialects._ops import ops as exir_ops from .node_visitor import NodeVisitor, register_node_visitor -from .qnn_constants import OpElementWiseAdd, QNN_OP_PACKAGE_NAME_QTI_AISW +from .qnn_constants import OpReshape, QNN_OP_PACKAGE_NAME_QTI_AISW @register_node_visitor @@ -31,7 +31,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = node.args[1] input_tensor = self.get_tensor(input_node, node) copy_inp_tensor_wrapper = self.define_tensor( input_node, @@ -40,31 +40,8 @@ def define_node( nodes_to_wrappers, is_input_tensor=True, ) - # 'graph', 'name', 'op', 'target', 'args', and 'kwargs' - zero_input_node = torch.fx.Node( - node.graph, - node.name + "_runtime_scalar", - "call_function", - exir_ops.edge.aten.scalar_tensor.default, - (), # args - {}, # kwargs - ) - zero_input_tensor = torch.tensor(0, dtype=input_tensor.dtype) - if quant_attrs := input_node.meta.get(QCOM_QUANT_ATTRS): - quant_attrs = quant_attrs.copy() - quant_attrs[QCOM_ZERO_POINT] = 0 - quant_attrs[QCOM_SCALE] = 1 - zero_input_node.meta[QCOM_QUANT_ATTRS] = quant_attrs - - zero_tensor_wrapper = self.define_tensor( - zero_input_node, - zero_input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, - nodes_to_wrappers, - is_input_tensor=True, - ) - copy_input_tensors = [copy_inp_tensor_wrapper, zero_tensor_wrapper] + copy_input_tensors = [copy_inp_tensor_wrapper] if quant_attrs := input_node.meta.get(QCOM_QUANT_ATTRS): quant_attrs = quant_attrs.copy() @@ -83,7 +60,7 @@ def define_node( copy_op = PyQnnWrapper.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, - OpElementWiseAdd.op_name, + OpReshape.op_name, ) copy_op.AddInputTensors(copy_input_tensors) copy_op.AddOutputTensors(copy_output_tensors) diff --git a/backends/qualcomm/builders/utils.py b/backends/qualcomm/builders/utils.py index ede32a5e659..ec5b92176f1 100755 --- a/backends/qualcomm/builders/utils.py +++ b/backends/qualcomm/builders/utils.py @@ -75,6 +75,23 @@ def is_graph_input( return tensor.op == "placeholder" and not is_parameter(tensor, edge_program) +def is_mutable_buffer_input( + tensor: torch.fx.Node, edge_program: torch.export.ExportedProgram +) -> bool: + """ + Check if the given tensor is a mutable buffer input + + Args: + tensor: EdgeIR Tensor that is being checked for mutable buffer input + """ + if tensor.op == "placeholder" and is_buffer(edge_program, tensor): + fqn = edge_program.graph_signature.inputs_to_buffers[tensor.target] + # if the buffer is mutated then record that + if fqn in edge_program.graph_signature.buffers_to_mutate.values(): + return True + return False + + def is_graph_output(tensor: torch.fx.Node) -> bool: """ Check if the given tensor is used as a graph output @@ -91,6 +108,26 @@ def is_graph_output(tensor: torch.fx.Node) -> bool: return False +def is_mutable_buffer_output( + tensor: 
torch.fx.Node, edge_program: torch.export.ExportedProgram +) -> bool: + """ + Check if the given tensor is a mutable buffer output + + Args: + tensor: EdgeIR Tensor that is being checked for mutable buffer output + """ + for user in tensor.users.keys(): + # getitem node is skiped, check the op_skip_ops.py + if user.op == "output" or ( + user.target.__name__ == "getitem" and is_graph_output(user) + ): + # if the buffer is mutated then record that + if tensor.name in edge_program.graph_signature.buffers_to_mutate.keys(): + return True + return False + + def is_constant( tensor: torch.fx.Node, edge_program: torch.export.ExportedProgram ) -> bool: diff --git a/backends/qualcomm/partition/qnn_partitioner.py b/backends/qualcomm/partition/qnn_partitioner.py index 659bda517f0..38294452225 100644 --- a/backends/qualcomm/partition/qnn_partitioner.py +++ b/backends/qualcomm/partition/qnn_partitioner.py @@ -23,7 +23,7 @@ Partitioner, PartitionResult, ) -from executorch.exir.backend.utils import tag_constant_data +from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer from torch.fx.passes.infra.partitioner import Partition from torch.fx.passes.operator_support import OperatorSupportBase @@ -136,27 +136,13 @@ def tag_nodes( node.meta["delegation_tag"] = delegation_tag self.partition_tags[delegation_tag] = self.delegation_spec - # need to take care of consumed constants - consumed_constants = ( - *edge_program.graph_signature.inputs_to_buffers, - *edge_program.graph_signature.inputs_to_parameters, - ) - for node in edge_program.graph_module.graph.nodes: - # find placeholders as lifted_constants - if node.op != "placeholder" or len(node.users) != 0: - continue - - if node.name in consumed_constants: - # does no harm to merge them into last partition, - # since they will all be removed in following stage - node.meta["delegation_tag"] = delegation_tag - # override def partition(self, edge_program: torch.export.ExportedProgram) -> PartitionResult: partitions = self.generate_partitions(edge_program) if len(partitions) != 0: self.tag_nodes(partitions, edge_program) tag_constant_data(edge_program) + tag_mutated_buffer(edge_program) for node in edge_program.graph_module.graph.nodes: if hasattr(node, "meta"): # pop certain keys in meta for not affecting the passes in compilation diff --git a/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp b/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp index f22a4b90035..7cb301374ad 100644 --- a/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp +++ b/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp @@ -13,7 +13,6 @@ #include #include -// #include namespace executorch { namespace backends { namespace qnn { @@ -29,7 +28,6 @@ using executorch::runtime::MemoryAllocator; using executorch::runtime::Result; // ========== Public method implementations ========================= constexpr const char* QNN_COMPILE_SPEC = "qnn_compile_spec"; -// static int hi = 0; Result QnnExecuTorchBackend::init( BackendInitContext& context, FreeableBuffer* processed, @@ -40,11 +38,6 @@ Result QnnExecuTorchBackend::init( qnn_context_blob.buffer = const_cast(processed->data()); qnn_context_blob.nbytes = processed->size(); - // std::string path_ = "model_"+std::to_string(hi)+".bin"; - // std::ofstream fout(path_, std::ios::binary); - // fout.write(static_cast(processed->data()), static_cast(processed->size())); - // fout.flush(); - // hi++; // convert CompileSpec to qnn ExecuTorch option for (auto& compile_spec : compile_specs) { @@ -189,7 +182,7 @@ Result 
QnnExecuTorchBackend::init( } return qnn_manager; } -// static int qq = 0; + Error QnnExecuTorchBackend::execute( BackendExecutionContext& context, DelegateHandle* handle, @@ -204,37 +197,34 @@ Error QnnExecuTorchBackend::execute( std::vector input_tensor_structs; std::vector output_tensor_structs; + int args_index = 0; input_tensor_structs.reserve(input_tensors.size()); - for (int i = 0; i < input_tensors.size(); ++i) { - if (qnn_manager->RegisterMem( - args[i]->toTensor().mutable_data_ptr(), input_tensors[i]) != - Error::Ok) { - // update data ptr only should be fine - input_tensors[i]->FillDataBuffer( - args[i]->toTensor().const_data_ptr(), false /* copy_data */); - // if(qq < input_tensors.size()){ - // std::string path_ = "qinput_"+std::to_string(qq)+".raw"; - // std::ofstream fout(path_, std::ios::binary); - // fout.write(static_cast(args[i]->toTensor().const_data_ptr()), input_tensors[i]->GetBytes()); - // fout.flush(); - // qq++; - // } - + for (const auto& input_tensor : input_tensors){ + if (input_tensor->GetName().find("mutbuf_") == std::string::npos){ + if (qnn_manager->RegisterMem( + args[args_index]->toTensor().mutable_data_ptr(), input_tensor) != + Error::Ok) { + // update data ptr only should be fine + input_tensor->FillDataBuffer( + args[args_index]->toTensor().const_data_ptr(), false /* copy_data */); + } + args_index++; } - input_tensor_structs.push_back(input_tensors[i]->CloneTensorStruct()); + + input_tensor_structs.push_back(input_tensor->CloneTensorStruct()); } - int output_index = input_tensors.size(); + for (const auto& output_tensor : output_tensors) { // pos=0 limits the search to the prefix - if (output_tensor->GetName().rfind("output_", 0) == 0) { + if (output_tensor->GetName().rfind("output_", 0) == 0 && output_tensor->GetName().find("mutbuf_") == std::string::npos) { void* mutable_data_ptr = - args[output_index]->toTensor().mutable_data_ptr(); + args[args_index]->toTensor().mutable_data_ptr(); if (qnn_manager->RegisterMem(mutable_data_ptr, output_tensor) != Error::Ok) { output_tensor->FillDataBuffer(mutable_data_ptr, false /* copy_data */); } - output_index++; + args_index++; } output_tensor_structs.push_back(output_tensor->CloneTensorStruct()); } diff --git a/backends/qualcomm/runtime/QnnManager.cpp b/backends/qualcomm/runtime/QnnManager.cpp index 9eeb6a8a016..e7189a98335 100644 --- a/backends/qualcomm/runtime/QnnManager.cpp +++ b/backends/qualcomm/runtime/QnnManager.cpp @@ -16,6 +16,7 @@ #include #include #include +#include namespace executorch { namespace backends { @@ -33,6 +34,16 @@ bool CompareExportedInput( return numA < numB; } +int ExtractMutableBufferNumber(const std::string& name) { + std::string prefix = "mutbuf_"; + size_t startPos = name.find(prefix); + if (startPos != std::string::npos) { + startPos += prefix.length(); + return std::stoi(name.substr(startPos)); + } + return -1; +} + QnnManager::~QnnManager() { backend_params_ptr_.reset(new BackendConfigParameters()); logger_.reset(); @@ -324,9 +335,20 @@ Error QnnManager::AllocateTensor() { std::vector output_tensors = backend_params_ptr_->qnn_context_ptr_->GetGraphOutputs(); + // Mapping memory address for the input and output of mutable buffer + std::unordered_map mutable_buffer_id_to_memory_map; + for (auto& tensor : input_tensors) { std::shared_ptr tensor_wrapper = CreateTensorWrapper(tensor); tensor_wrapper->UpdateQnnTensorMeta(tensor); + + int mutable_buffer_id = ExtractMutableBufferNumber(tensor_wrapper->GetName()); + if (mutable_buffer_id != -1){ + // Delegate maintain the memory for 
mutable buffer + tensor_wrapper->AllocateDataBuffer(); + mutable_buffer_id_to_memory_map[mutable_buffer_id] = tensor_wrapper->GetStaticTensorData(); + } + input_tensors_.emplace_back(std::move(tensor_wrapper)); } if (!options_->is_from_context_binary()) { @@ -347,6 +369,13 @@ Error QnnManager::AllocateTensor() { if (IsTensorDump()) { tensor_wrapper->AllocateDataBuffer(); } + + int mutable_buffer_id = ExtractMutableBufferNumber(tensor_wrapper->GetName()); + if(mutable_buffer_id!=-1 && mutable_buffer_id_to_memory_map.find(mutable_buffer_id) != mutable_buffer_id_to_memory_map.end()){ + // Fill the same memory for I/O of mutable buffer + tensor_wrapper->FillDataBuffer(mutable_buffer_id_to_memory_map[mutable_buffer_id], false /* copy_data */); + } + output_tensors_.emplace_back(std::move(tensor_wrapper)); } return Error::Ok;
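
Editor's note (not part of the patch series): a minimal sanity-check sketch for the FeedForwardConv2D rewrite introduced in PATCH 2/5 (examples/models/llama/source_transformation/sdpa.py). It assumes the patched tree is importable as `executorch`; the tensor sizes below are hypothetical and chosen only to keep the check fast. The rewrite works because a Conv2d with kernel_size=1 whose weight is the Linear weight reshaped to (out, in, 1, 1) computes the same matmul as the Linear layer, which is what lets the Qualcomm backend lower the feed-forward as convolutions.

    # Hedged sketch: verify FeedForwardConv2D matches the original SwiGLU
    # feed-forward w2(silu(w1(x)) * w3(x)). Dimensions are hypothetical.
    import torch
    import torch.nn.functional as F

    from executorch.examples.models.llama.source_transformation.sdpa import (
        FeedForwardConv2D,
    )

    dim, hidden = 64, 128  # hypothetical small sizes for a quick check
    w1 = torch.nn.Linear(dim, hidden, bias=False)
    w2 = torch.nn.Linear(hidden, dim, bias=False)
    w3 = torch.nn.Linear(dim, hidden, bias=False)

    # The module reshapes the Linear weights into 1x1 Conv2d weights internally.
    ff_conv = FeedForwardConv2D(w1, w2, w3)

    x = torch.randn(1, 8, dim)  # (batch, seq_len, dim), exercising the rank-3 path
    with torch.no_grad():
        ref = w2(F.silu(w1(x)) * w3(x))  # reference SwiGLU feed-forward
        torch.testing.assert_close(ff_conv(x), ref, rtol=1e-4, atol=1e-4)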