diff --git a/backends/qualcomm/_passes/layout_transform.py b/backends/qualcomm/_passes/layout_transform.py
index 829c11fda42..a26e06d9b27 100644
--- a/backends/qualcomm/_passes/layout_transform.py
+++ b/backends/qualcomm/_passes/layout_transform.py
@@ -62,6 +62,7 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.prelu.default,
         exir_ops.edge.aten.relu.default,
         exir_ops.edge.aten._softmax.default,  # TODO: Need to find a new solution to do "axis_order" to transform axis.
+        exir_ops.edge.aten.sigmoid.default,
         exir_ops.edge.aten.sqrt.default,
         exir_ops.edge.aten.sub.Tensor,
         exir_ops.edge.aten.sum.dim_IntList,
diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py
old mode 100644
new mode 100755
index 74fd58a3ec3..eab5a56b385
--- a/backends/qualcomm/builders/__init__.py
+++ b/backends/qualcomm/builders/__init__.py
@@ -14,6 +14,7 @@
     op_ceil,
     op_clamp,
     op_conv2d,
+    op_copy,
     op_depth_to_space,
     op_dequantize,
     op_div,
@@ -70,6 +71,7 @@
     op_ceil,
     op_clamp,
     op_conv2d,
+    op_copy,
     op_depth_to_space,
     op_dequantize,
     op_div,
diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py
index 514bc6efd78..e2f845d92df 100644
--- a/backends/qualcomm/builders/node_visitor.py
+++ b/backends/qualcomm/builders/node_visitor.py
@@ -36,6 +36,8 @@
     get_parameter,
     is_graph_input,
     is_graph_output,
+    is_mutable_buffer_input,
+    is_mutable_buffer_output,
     is_parameter,
 )
 
@@ -214,7 +216,7 @@ def get_tensor_type(
         node: torch.fx.Node,
         tensor_type: PyQnnWrapper.Qnn_TensorType_t,
     ) -> PyQnnWrapper.Qnn_TensorType_t:
-        is_input = is_graph_input(node, self.edge_program)
+        is_input = is_graph_input(node, self.edge_program) or is_mutable_buffer_input(node, self.edge_program)
         is_output = is_graph_output(node)
         # handle logic for input/output tensors
         if is_input or is_output:
@@ -245,6 +247,29 @@ def get_data_type(
 
         return QNN_TENSOR_TYPE_MAP[tensor.dtype]
 
+    def get_tensor_name(
+        self,
+        node: torch.fx.Node,
+        wrapper_idx: int = 0,
+    ):
+        tensor_name = f"{node.name}_{wrapper_idx}"
+        # The `input_{id}` is utilized for sorting at runtime. Due to multiple passes in qnn_preprocess,
+        # the input order between QNN and the original graph's forward function may differ.
+        # The `mutbuf_{id}` is utilized for mapping I/O of mutable buffer at runtime.
+        # The `output_` is identified as the graph's output at runtime to prevent confusion with per_tensor_dump.
+        if is_mutable_buffer_input(node, self.edge_program):
+            fqn = self.edge_program.graph_signature.inputs_to_buffers[node.target]
+            position_index = list(self.edge_program.graph_signature.buffers_to_mutate.values()).index(fqn)
+            tensor_name = f"input_{str(self.external_ids[node])}_mutbuf_{str(position_index)}_{tensor_name}"
+        elif is_graph_input(node, self.edge_program):
+            tensor_name = f"input_{str(self.external_ids[node])}_{tensor_name}"
+        elif is_mutable_buffer_output(node, self.edge_program):
+            position_index = list(self.edge_program.graph_signature.buffers_to_mutate.keys()).index(node.name)
+            tensor_name = f"output_mutbuf_{position_index}_{tensor_name}"
+        elif is_graph_output(node):
+            tensor_name = f"output_{tensor_name}"
+        return tensor_name
+
     def define_custom_tensor_wrapper(
         self,
         node_name: str,
@@ -305,11 +330,7 @@ def define_tensor(
         if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
             return cached
 
-        tensor_name = f"{node.name}_{wrapper_idx}"
-        if is_graph_input(node, self.edge_program):
-            tensor_name = "input_" + str(self.external_ids[node]) + "_" + tensor_name
-        if is_graph_output(node):
-            tensor_name = "output_" + tensor_name
+        tensor_name = self.get_tensor_name(node, wrapper_idx)
         dims = [1] if len(tensor.size()) == 0 else tensor.size()
         tensor_type = self.get_tensor_type(node, tensor_type)
         quant_encoding, quant_configs = self.get_quant_encoding_conf(
@@ -381,7 +402,7 @@ def generate_node_to_external_map(
         # The order in which we visit the placeholder node is same as the *args
         # order for the forward(*args) signature for this gm. Using the order of
         # the nodes as external_id to extract the right arg from *args at runtime
-        if is_graph_input(node, edge_program):
+        if is_graph_input(node, edge_program) or is_mutable_buffer_input(node, edge_program):
             node_to_external_map[node] = len(node_to_external_map)
     for node in edge_program.graph_module.graph.nodes:
         if is_graph_output(node):
diff --git a/backends/qualcomm/builders/op_copy.py b/backends/qualcomm/builders/op_copy.py
new file mode 100755
index 00000000000..aadfc6df1f1
--- /dev/null
+++ b/backends/qualcomm/builders/op_copy.py
@@ -0,0 +1,68 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Dict + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper + +import torch +from executorch.backends.qualcomm.utils.constants import ( + QCOM_QUANT_ATTRS, + QCOM_SCALE, + QCOM_ZERO_POINT, +) +from executorch.exir.dialects._ops import ops as exir_ops + +from .node_visitor import NodeVisitor, register_node_visitor +from .qnn_constants import OpReshape, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class Copy(NodeVisitor): + target = ["aten.copy.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + ) -> PyQnnWrapper.PyQnnOpWrapper: + input_node = node.args[1] + input_tensor = self.get_tensor(input_node, node) + copy_inp_tensor_wrapper = self.define_tensor( + input_node, + input_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + is_input_tensor=True, + ) + + copy_input_tensors = [copy_inp_tensor_wrapper] + + if quant_attrs := input_node.meta.get(QCOM_QUANT_ATTRS): + quant_attrs = quant_attrs.copy() + # Because there is no output after convert_pt2e, the QCOM_QUANT_ATTRS of node is none + node.meta[QCOM_QUANT_ATTRS] = quant_attrs + output_tensor = self.get_tensor(node, node) + output_tensor_wrapper = self.define_tensor( + node, + output_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + is_input_tensor=False, + ) + copy_output_tensors = [output_tensor_wrapper] + + copy_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpReshape.op_name, + ) + copy_op.AddInputTensors(copy_input_tensors) + copy_op.AddOutputTensors(copy_output_tensors) + + return copy_op diff --git a/backends/qualcomm/builders/utils.py b/backends/qualcomm/builders/utils.py index ede32a5e659..ec5b92176f1 100755 --- a/backends/qualcomm/builders/utils.py +++ b/backends/qualcomm/builders/utils.py @@ -75,6 +75,23 @@ def is_graph_input( return tensor.op == "placeholder" and not is_parameter(tensor, edge_program) +def is_mutable_buffer_input( + tensor: torch.fx.Node, edge_program: torch.export.ExportedProgram +) -> bool: + """ + Check if the given tensor is a mutable buffer input + + Args: + tensor: EdgeIR Tensor that is being checked for mutable buffer input + """ + if tensor.op == "placeholder" and is_buffer(edge_program, tensor): + fqn = edge_program.graph_signature.inputs_to_buffers[tensor.target] + # if the buffer is mutated then record that + if fqn in edge_program.graph_signature.buffers_to_mutate.values(): + return True + return False + + def is_graph_output(tensor: torch.fx.Node) -> bool: """ Check if the given tensor is used as a graph output @@ -91,6 +108,26 @@ def is_graph_output(tensor: torch.fx.Node) -> bool: return False +def is_mutable_buffer_output( + tensor: torch.fx.Node, edge_program: torch.export.ExportedProgram +) -> bool: + """ + Check if the given tensor is a mutable buffer output + + Args: + tensor: EdgeIR Tensor that is being checked for mutable buffer output + """ + for user in tensor.users.keys(): + # getitem node is skiped, check the op_skip_ops.py + if user.op == "output" or ( + user.target.__name__ == "getitem" and is_graph_output(user) + ): + # if the buffer is mutated then record that + if tensor.name in edge_program.graph_signature.buffers_to_mutate.keys(): + return True + return False + + def is_constant( tensor: torch.fx.Node, edge_program: torch.export.ExportedProgram ) -> bool: diff --git 
a/backends/qualcomm/partition/common_defs.py b/backends/qualcomm/partition/common_defs.py old mode 100644 new mode 100755 index d68441c2f79..8f490bf1ed3 --- a/backends/qualcomm/partition/common_defs.py +++ b/backends/qualcomm/partition/common_defs.py @@ -13,7 +13,7 @@ exir_ops.edge.aten.clone.default, exir_ops.edge.aten.full.default, exir_ops.edge.aten.slice_scatter.default, - exir_ops.edge.aten.copy.default, + exir_ops.edge.quantized_decomposed.embedding_4bit.dtype, ] to_be_implemented_operator = [ diff --git a/backends/qualcomm/partition/qnn_partitioner.py b/backends/qualcomm/partition/qnn_partitioner.py index 659bda517f0..38294452225 100644 --- a/backends/qualcomm/partition/qnn_partitioner.py +++ b/backends/qualcomm/partition/qnn_partitioner.py @@ -23,7 +23,7 @@ Partitioner, PartitionResult, ) -from executorch.exir.backend.utils import tag_constant_data +from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer from torch.fx.passes.infra.partitioner import Partition from torch.fx.passes.operator_support import OperatorSupportBase @@ -136,27 +136,13 @@ def tag_nodes( node.meta["delegation_tag"] = delegation_tag self.partition_tags[delegation_tag] = self.delegation_spec - # need to take care of consumed constants - consumed_constants = ( - *edge_program.graph_signature.inputs_to_buffers, - *edge_program.graph_signature.inputs_to_parameters, - ) - for node in edge_program.graph_module.graph.nodes: - # find placeholders as lifted_constants - if node.op != "placeholder" or len(node.users) != 0: - continue - - if node.name in consumed_constants: - # does no harm to merge them into last partition, - # since they will all be removed in following stage - node.meta["delegation_tag"] = delegation_tag - # override def partition(self, edge_program: torch.export.ExportedProgram) -> PartitionResult: partitions = self.generate_partitions(edge_program) if len(partitions) != 0: self.tag_nodes(partitions, edge_program) tag_constant_data(edge_program) + tag_mutated_buffer(edge_program) for node in edge_program.graph_module.graph.nodes: if hasattr(node, "meta"): # pop certain keys in meta for not affecting the passes in compilation diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py old mode 100644 new mode 100755 index 881d24bbb5e..d8c89d968ff --- a/backends/qualcomm/quantizer/custom_annotation.py +++ b/backends/qualcomm/quantizer/custom_annotation.py @@ -12,6 +12,7 @@ QuantizationConfig, ) from executorch.backends.qualcomm.quantizer.utils import QUANT_ANNOTATION_KEY +from executorch.exir.dialects._ops import ops as exir_ops from torch.ao.quantization.quantizer import ( QuantizationAnnotation, SharedQuantizationSpec, @@ -144,3 +145,35 @@ def annotate_matmul(node: Node, quantization_config: QuantizationConfig): for node in gm.graph.nodes: if node.op == "call_function" and node.target == torch.ops.aten.matmul.default: annotate_matmul(node, quantization_config_16a8w) + + +def get_custom_quant_ios_dtype( + cache_shape: torch.Size, + node: torch.fx.Node, + kv_dtype=torch.uint8, + sharding_dtype=torch.uint16, +): + """ + This function is specific for llama inputs and outputs + """ + if node.op == "placeholder" and "attention_sdpa_kv_cache_past_" in node.name: + return kv_dtype + + # Tag index put node before copy node, because copy is a skipped node in qnn + if ( + exir_ops.edge.aten.copy.default == node.target + and node.meta["val"].shape == cache_shape + ): + return kv_dtype + + # Tag sharding io + if 
exir_ops.edge.llama.fallback.default in [ + u.target for u in list(node.users.keys()) + ] + [node.target]: + return sharding_dtype + + # Tag index op as quantized tensors. It is caused by sharding + if exir_ops.edge.aten.index.Tensor in [ + u.target for u in list(node.users.keys()) + ] + [node.target]: + return sharding_dtype diff --git a/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp b/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp index 5a55df6da3f..7cb301374ad 100644 --- a/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp +++ b/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp @@ -11,6 +11,8 @@ #include #include #include +#include + namespace executorch { namespace backends { namespace qnn { @@ -185,6 +187,7 @@ Error QnnExecuTorchBackend::execute( BackendExecutionContext& context, DelegateHandle* handle, EValue** args) const { + auto begin = std::chrono::high_resolution_clock::now(); QnnManager* qnn_manager = static_cast(handle); std::vector> input_tensors = @@ -194,29 +197,34 @@ Error QnnExecuTorchBackend::execute( std::vector input_tensor_structs; std::vector output_tensor_structs; + int args_index = 0; input_tensor_structs.reserve(input_tensors.size()); - for (int i = 0; i < input_tensors.size(); ++i) { - if (qnn_manager->RegisterMem( - args[i]->toTensor().mutable_data_ptr(), input_tensors[i]) != - Error::Ok) { - // update data ptr only should be fine - input_tensors[i]->FillDataBuffer( - args[i]->toTensor().const_data_ptr(), false /* copy_data */); + for (const auto& input_tensor : input_tensors){ + if (input_tensor->GetName().find("mutbuf_") == std::string::npos){ + if (qnn_manager->RegisterMem( + args[args_index]->toTensor().mutable_data_ptr(), input_tensor) != + Error::Ok) { + // update data ptr only should be fine + input_tensor->FillDataBuffer( + args[args_index]->toTensor().const_data_ptr(), false /* copy_data */); + } + args_index++; } - input_tensor_structs.push_back(input_tensors[i]->CloneTensorStruct()); + + input_tensor_structs.push_back(input_tensor->CloneTensorStruct()); } - int output_index = input_tensors.size(); + for (const auto& output_tensor : output_tensors) { // pos=0 limits the search to the prefix - if (output_tensor->GetName().rfind("output_", 0) == 0) { + if (output_tensor->GetName().rfind("output_", 0) == 0 && output_tensor->GetName().find("mutbuf_") == std::string::npos) { void* mutable_data_ptr = - args[output_index]->toTensor().mutable_data_ptr(); + args[args_index]->toTensor().mutable_data_ptr(); if (qnn_manager->RegisterMem(mutable_data_ptr, output_tensor) != Error::Ok) { output_tensor->FillDataBuffer(mutable_data_ptr, false /* copy_data */); } - output_index++; + args_index++; } output_tensor_structs.push_back(output_tensor->CloneTensorStruct()); } @@ -232,7 +240,12 @@ Error QnnExecuTorchBackend::execute( qnn_manager->ProfileExecuteData(context.event_tracer()) == Error::Ok, Internal, "Fail to profile graph"); + auto end = std::chrono::high_resolution_clock::now(); + auto elapsed = std::chrono::duration_cast(end - + begin); + QNN_EXECUTORCH_LOG_INFO( + "QNN Graph Execute Time in QnnExecuTorchBackend: %ld us", elapsed.count()); return Error::Ok; } diff --git a/backends/qualcomm/runtime/QnnManager.cpp b/backends/qualcomm/runtime/QnnManager.cpp index 9eeb6a8a016..e7189a98335 100644 --- a/backends/qualcomm/runtime/QnnManager.cpp +++ b/backends/qualcomm/runtime/QnnManager.cpp @@ -16,6 +16,7 @@ #include #include #include +#include namespace executorch { namespace backends { @@ -33,6 +34,16 @@ bool CompareExportedInput( return numA < numB; } +int 
ExtractMutableBufferNumber(const std::string& name) { + std::string prefix = "mutbuf_"; + size_t startPos = name.find(prefix); + if (startPos != std::string::npos) { + startPos += prefix.length(); + return std::stoi(name.substr(startPos)); + } + return -1; +} + QnnManager::~QnnManager() { backend_params_ptr_.reset(new BackendConfigParameters()); logger_.reset(); @@ -324,9 +335,20 @@ Error QnnManager::AllocateTensor() { std::vector output_tensors = backend_params_ptr_->qnn_context_ptr_->GetGraphOutputs(); + // Mapping memory address for the input and output of mutable buffer + std::unordered_map mutable_buffer_id_to_memory_map; + for (auto& tensor : input_tensors) { std::shared_ptr tensor_wrapper = CreateTensorWrapper(tensor); tensor_wrapper->UpdateQnnTensorMeta(tensor); + + int mutable_buffer_id = ExtractMutableBufferNumber(tensor_wrapper->GetName()); + if (mutable_buffer_id != -1){ + // Delegate maintain the memory for mutable buffer + tensor_wrapper->AllocateDataBuffer(); + mutable_buffer_id_to_memory_map[mutable_buffer_id] = tensor_wrapper->GetStaticTensorData(); + } + input_tensors_.emplace_back(std::move(tensor_wrapper)); } if (!options_->is_from_context_binary()) { @@ -347,6 +369,13 @@ Error QnnManager::AllocateTensor() { if (IsTensorDump()) { tensor_wrapper->AllocateDataBuffer(); } + + int mutable_buffer_id = ExtractMutableBufferNumber(tensor_wrapper->GetName()); + if(mutable_buffer_id!=-1 && mutable_buffer_id_to_memory_map.find(mutable_buffer_id) != mutable_buffer_id_to_memory_map.end()){ + // Fill the same memory for I/O of mutable buffer + tensor_wrapper->FillDataBuffer(mutable_buffer_id_to_memory_map[mutable_buffer_id], false /* copy_data */); + } + output_tensors_.emplace_back(std::move(tensor_wrapper)); } return Error::Ok; diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index 88a84f2f9a6..7b7fff3e197 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -71,6 +71,7 @@ QCOM_PASS_EXPAND_BROADCAST_SHAPE, QCOM_PASS_SKIP_ADVANCED_REQUANT, QCOM_QNN_COMPILE_SPEC, + QCOM_QUANTIZED_IO, ) from executorch.exir import ExirExportedProgram @@ -165,8 +166,8 @@ def __init__(self, weight, bias=None): super().__init__() use_bias = bias is not None self.conv = torch.nn.Conv2d( - in_channels=weight.shape[0], - out_channels=weight.shape[1], + in_channels=weight.shape[1], + out_channels=weight.shape[0], kernel_size=1, padding=0, bias=use_bias, @@ -876,3 +877,12 @@ def get_soc_to_chipset_map(): "SM8475": QcomChipset.SM8475, "SM8450": QcomChipset.SM8450, } + + +def tag_quant_io(gm: torch.fx.GraphModule, get_quant_io_dtype_fn: Callable): + """ + Tag io nodes which get/output quantized tensor. 
No need to insert q/dq in qnn_preprocess + """ + for node in gm.graph.nodes: + if dtype := get_quant_io_dtype_fn(node): + node.meta[QCOM_QUANTIZED_IO] = dtype diff --git a/examples/models/llama/export_llama.py b/examples/models/llama/export_llama.py index 3d0d1b7bcfb..1899ccf4df6 100644 --- a/examples/models/llama/export_llama.py +++ b/examples/models/llama/export_llama.py @@ -7,11 +7,14 @@ # Example script for exporting Llama2 to flatbuffer import logging +import sys import torch from .export_llama_lib import build_args_parser, export_llama +sys.setrecursionlimit(4096) + FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 940bcaecbc7..469d44079c5 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -50,6 +50,8 @@ fuse_layer_norms, get_model_with_r1_r2, ) + +from .source_transformation.attention import replace_attention_to_attention_sha from .source_transformation.quantize import ( get_quant_embedding_transform, get_quant_weight_transform, @@ -64,6 +66,7 @@ replace_causal_mask, replace_kv_cache_with_coreml_kv_cache, replace_kv_cache_with_simple_kv_cache, + replace_feedforward_to_conv2d, replace_sdpa_with_coreml_sdpa, replace_sdpa_with_custom_op, replace_sdpa_with_flex_sdpa, @@ -174,6 +177,12 @@ def build_args_parser() -> argparse.ArgumentParser: help="checkpoint directory. Use with a sharded checkpoint, not for the standard llama2 model. Note, checkpoint_dir takes precedence over checkpoint if both are set.", ) + parser.add_argument( + "--use_qnn_sha", + action="store_true", + help="Change multi head attention to multiple single head attention for qnn backend (Qualcomm)", + ) + parser.add_argument( "--calibration_tasks", nargs="+", @@ -642,7 +651,7 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901 ) ) # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils` - from executorch.backends.qualcomm.utils.utils import _transform + from executorch.backends.qualcomm.utils.utils import _transform, tag_quant_io # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`, Optional type has no attribute `exported_program` _transform(builder_exported_to_edge.edge_manager.exported_program()) @@ -654,7 +663,32 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901 builder_exported_to_edge.metadata["get_n_layers"], shares=args.num_sharding, ) + from functools import partial + from executorch.backends.qualcomm.quantizer.custom_annotation import ( + get_custom_quant_ios_dtype, + ) + atten = builder_exported_to_edge.model.layers[0].attention + if args.use_qnn_sha: + cache_shape = torch.Size( + (atten.max_batch_size, atten.max_seq_len, atten.head_dim) + ) + else: + cache_shape = torch.Size( + ( + atten.max_batch_size, + atten.max_seq_len, + atten.n_kv_heads, + atten.head_dim, + ) + ) + tag_quant_io( + builder_exported_to_edge.edge_manager.exported_program().graph_module, + partial( + get_custom_quant_ios_dtype, + cache_shape, + ), + ) logging.info("Lowering model using following partitioner(s): ") for partitioner in partitioners: logging.info(f"--> {partitioner.__class__.__name__}") @@ -919,15 +953,29 @@ def _get_source_transforms( # noqa convert_linear_to_conv2d, ) - transforms.append(replace_kv_cache_with_simple_kv_cache) - 
transforms.append(replace_sdpa_with_flex_sdpa) - transforms.append(replace_causal_mask) - transforms.append(replace_rms_norm_with_native_rms_norm) - if args.optimized_rotation_path: - transforms.append(fuse_layer_norms) - transforms.append(get_model_with_r1_r2(args.optimized_rotation_path)) - # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`. - transforms.append(convert_linear_to_conv2d) + if args.use_qnn_sha: + if args.optimized_rotation_path: + transforms.append(fuse_layer_norms) + transforms.append( + get_model_with_r1_r2(args.optimized_rotation_path) + ) + transforms.append(replace_attention_to_attention_sha) + transforms.append(replace_causal_mask) + transforms.append(replace_rms_norm_with_native_rms_norm) + transforms.append(replace_feedforward_to_conv2d) + transforms.append(convert_linear_to_conv2d) + else: + transforms.append(replace_kv_cache_with_simple_kv_cache) + transforms.append(replace_sdpa_with_flex_sdpa) + transforms.append(replace_causal_mask) + transforms.append(replace_rms_norm_with_native_rms_norm) + if args.optimized_rotation_path: + transforms.append(fuse_layer_norms) + transforms.append( + get_model_with_r1_r2(args.optimized_rotation_path) + ) + transforms.append(replace_feedforward_to_conv2d) + transforms.append(convert_linear_to_conv2d) elif args.mps: # Currently mps doesn't support sdpa op, use the simpler decomposition diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py index 3c4e3f13e6f..d5ed038ade6 100644 --- a/examples/models/llama/llama_transformer.py +++ b/examples/models/llama/llama_transformer.py @@ -263,21 +263,22 @@ class Attention(nn.Module): def __init__(self, args: ModelArgs, layer_id: int): super().__init__() self.use_kv_cache = args.use_kv_cache - self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads - assert args.n_heads % self.n_kv_heads == 0 + self.n_heads = args.n_heads + self.n_kv_heads = self.n_heads if args.n_kv_heads is None else args.n_kv_heads + assert self.n_heads % self.n_kv_heads == 0 model_parallel_size = 1 - self.n_local_heads = args.n_heads // model_parallel_size + self.n_local_heads = self.n_heads // model_parallel_size self.n_local_kv_heads = self.n_kv_heads // model_parallel_size self.n_rep = self.n_local_heads // self.n_local_kv_heads - self.head_dim = args.dim // args.n_heads + self.head_dim = args.dim // self.n_heads self.max_batch_size = args.max_batch_size self.max_seq_len = args.max_seq_len self.dim = args.dim - # args.dim = 4096, args.n_heads = 32, self.head_dim = 4096 / 32 = 125 - self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False) - self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False) - self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False) - self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False) + # args.dim = 4096, self.n_heads = 32, self.head_dim = 4096 / 32 = 125 + self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False) + self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False) self.layer_id = layer_id diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py index e6f39e0cad5..50db5ad98be 100644 --- a/examples/models/llama/model.py +++ b/examples/models/llama/model.py @@ -245,7 +245,7 @@ def get_example_inputs_kvcache_sdpa(self): else: return ( 
torch.tensor( - [[1]], dtype=torch.long + [[1]], dtype=torch.int32 ), # tokens, with kv cache our input token length is always just 1 token. torch.tensor( [0], dtype=torch.long diff --git a/examples/models/llama/source_transformation/apply_spin_quant_r1_r2.py b/examples/models/llama/source_transformation/apply_spin_quant_r1_r2.py index e71007b1958..60bbad5598d 100644 --- a/examples/models/llama/source_transformation/apply_spin_quant_r1_r2.py +++ b/examples/models/llama/source_transformation/apply_spin_quant_r1_r2.py @@ -98,7 +98,7 @@ def get_model_with_r1_r2(optimized_rotation_path: str): def apply_spin_quant_r1_r2(model: torch.nn.Module, optimized_rotation_path: str): - optimized_rotation = torch.load(optimized_rotation_path, weights_only=True) + optimized_rotation = torch.load(optimized_rotation_path, weights_only=True, map_location=torch.device('cpu')) R1 = optimized_rotation["R1"].to(torch.float32) config = model.params num_heads = config.n_heads diff --git a/examples/models/llama/source_transformation/attention.py b/examples/models/llama/source_transformation/attention.py new file mode 100644 index 00000000000..59d989b7d92 --- /dev/null +++ b/examples/models/llama/source_transformation/attention.py @@ -0,0 +1,219 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +# Example script for exporting Llama2 to flatbuffer + +import math +from typing import List, Optional, Tuple + +import torch +from executorch.examples.models.llama.llama_transformer import Attention +from torch import nn + + +def apply_rotary_emb_single( + x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor +) -> torch.Tensor: + x_r, x_i = x[..., ::2], x[..., 1::2] + + x_out_r = x_r * freqs_cos - x_i * freqs_sin + x_out_i = x_r * freqs_sin + x_i * freqs_cos + + x_out = torch.cat([x_out_r, x_out_i], dim=-1) + return x_out + + +class KVCacheSha(torch.nn.Module): + def __init__( + self, + max_batch_size: int, + max_seq_length: int, + n_heads: int, + head_dim: int, + dtype=torch.float32, + ): + super().__init__() + + # a buffer per head + cache_shape = (max_batch_size, max_seq_length, head_dim) + for i in range(n_heads): + self.register_buffer( + f"past_k_caches_{i}", + torch.zeros(cache_shape, dtype=dtype, device="cpu"), + persistent=False, + ) + self.register_buffer( + f"past_v_caches_{i}", + torch.zeros(cache_shape, dtype=dtype, device="cpu"), + persistent=False, + ) + + def update( + self, + input_pos: torch.Tensor, + k_val: torch.Tensor, + v_val: torch.Tensor, + cache_idx: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + new_k = torch.ops.aten.index_put_( + getattr(self, f"past_k_caches_{cache_idx}"), [None, input_pos], k_val + ) + new_v = torch.ops.aten.index_put_( + getattr(self, f"past_v_caches_{cache_idx}"), [None, input_pos], v_val + ) + return new_k, new_v + + def get_cache(self, head_idx): + return getattr(self, f"past_k_caches_{head_idx}"), getattr( + self, f"past_v_caches_{head_idx}" + ) + + +class SDPASha(torch.nn.Module): + + def __init__( + self, + max_batch_size: int, + max_seq_length: int, + n_heads: int, + n_rep: int, + head_dim: int, + dim: int, + ): + super().__init__() + self.head_dim = head_dim + self.n_rep = n_rep + self.dim = dim + self.kv_cache = KVCacheSha( + max_batch_size, max_seq_length, n_heads // n_rep, head_dim + ) + self.scale_factor = math.sqrt(head_dim) + + def forward( + self, + input_pos: 
torch.Tensor, + qs: List[torch.Tensor], + ks: List[torch.Tensor], + vs: List[torch.Tensor], + mask, + ): + + transpose_ks = [] + for i in range(len(ks)): + new_k, _ = self.kv_cache.update(input_pos, ks[i], vs[i], i) + transpose_ks.append(new_k.transpose(-2, -1).contiguous()) + + output = [] + for i, q in enumerate(qs): + cache_idx = i // self.n_rep + _, v = self.kv_cache.get_cache(cache_idx) + + attn_mask = mask[input_pos] + + attn_weight = q @ transpose_ks[cache_idx] / self.scale_factor + attn_weight += attn_mask + attn_weight = torch.softmax(attn_weight, dim=-1) + output.append(attn_weight @ v.contiguous()) + + return torch.cat(output, dim=-1) + + +class AttentionSha(nn.Module): + def __init__(self, attention_mha: nn.Module): + super().__init__() + if not attention_mha.use_kv_cache: + raise NotImplementedError("bert mode is not support") + + self.n_heads = attention_mha.n_heads + self.n_kv_heads = attention_mha.n_kv_heads + self.n_rep = self.n_heads // self.n_kv_heads + self.dim = attention_mha.dim + self.max_batch_size = attention_mha.max_batch_size + self.max_seq_len = attention_mha.max_seq_len + self.head_dim = attention_mha.dim // self.n_heads + self.SDPA = SDPASha( + self.max_batch_size, + self.max_seq_len, + self.n_heads, + self.n_rep, + self.head_dim, + self.dim, + ) + self.wq = nn.ModuleList( + [ + nn.Linear(self.dim, self.head_dim, bias=False) + for _ in range(self.n_heads) + ] + ) + self.wk = nn.ModuleList( + [ + nn.Linear(self.dim, self.head_dim, bias=False) + for _ in range(self.n_kv_heads) + ] + ) + self.wv = nn.ModuleList( + [ + nn.Linear(self.dim, self.head_dim, bias=False) + for _ in range(self.n_kv_heads) + ] + ) + + for i in range(self.n_heads): + self.wq[i].weight.data.copy_( + attention_mha.wq.weight[i * self.head_dim : (i + 1) * self.head_dim] + ) + for i in range(self.n_kv_heads): + self.wk[i].weight.data.copy_( + attention_mha.wk.weight[i * self.head_dim : (i + 1) * self.head_dim] + ) + self.wv[i].weight.data.copy_( + attention_mha.wv.weight[i * self.head_dim : (i + 1) * self.head_dim] + ) + self.wo = attention_mha.wo + + causal_mask = torch.tril( + torch.ones( + self.max_seq_len, + self.max_seq_len, + dtype=torch.bool, + device="cpu", + ) + ) + self.register_buffer("mask", causal_mask, persistent=False) + + def forward( + self, + x: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + input_pos: Optional[torch.Tensor] = None, + ): + # QKV + q = [wq(x) for wq in self.wq] + k = [wk(x) for wk in self.wk] + v = [wv(x) for wv in self.wv] + for i in range(len(q)): + q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin) + for i in range(len(k)): + k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin) + + output = self.SDPA(input_pos, q, k, v, self.mask) + return self.wo(output) + + +def replace_attention_to_attention_sha(module: torch.nn.Module): + for name, child in module.named_children(): + if isinstance(child, Attention): + setattr( + module, + name, + AttentionSha(child), + ) + else: + replace_attention_to_attention_sha(child) + return module diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py index f8362648f32..9c74143646a 100644 --- a/examples/models/llama/source_transformation/sdpa.py +++ b/examples/models/llama/source_transformation/sdpa.py @@ -12,8 +12,9 @@ from typing import Tuple, Union import torch +import torch.nn.functional as F -from executorch.examples.models.llama.llama_transformer import KVCache, SDPA +from executorch.examples.models.llama.llama_transformer 
import KVCache, SDPA, FeedForward from executorch.examples.models.llama.source_transformation.quantized_kv_cache import ( QuantizedKVCache, ) @@ -171,12 +172,14 @@ def __init__( self, kv_cache: KVCache, dim: int, + head_dim: int, n_rep: int, ): super().__init__() self.kv_cache = kv_cache self.dim = dim self.n_rep = n_rep + self.scale_factor = math.sqrt(head_dim) def forward( self, @@ -195,8 +198,7 @@ def forward( v = repeat_kv(v, self.n_rep) attn_mask = mask[input_pos] - scale_factor = 1 / math.sqrt(q.size(-1)) - attn_weight = q @ k.transpose(-2, -1) * scale_factor + attn_weight = q @ k.transpose(-2, -1) / self.scale_factor attn_weight += attn_mask attn_weight = torch.softmax(attn_weight, dim=-1) y = attn_weight @ v @@ -223,7 +225,7 @@ def replace_sdpa_with_flex_sdpa(module: torch.nn.Module): setattr( module, name, - SDPAFlex(child.kv_cache, child.dim, child.n_rep), + SDPAFlex(child.kv_cache, child.dim, child.head_dim, child.n_rep), ) else: replace_sdpa_with_flex_sdpa(child) @@ -428,3 +430,50 @@ def replace_causal_mask(module: torch.nn.Module): for _, child in module.named_children(): replace_causal_mask(child) return module + +class FeedForwardConv2D(torch.nn.Module): + def __init__(self, w1: torch.nn.Linear, w2: torch.nn.Linear, w3: torch.nn.Linear): + super().__init__() + self.w1_conv = torch.nn.Conv2d( + in_channels=w1.weight.shape[1], + out_channels=w1.weight.shape[0], + kernel_size=1, + padding=0, + bias=False, + ) + self.w2_conv = torch.nn.Conv2d( + in_channels=w2.weight.shape[1], + out_channels=w2.weight.shape[0], + kernel_size=1, + padding=0, + bias=False, + ) + self.w3_conv = torch.nn.Conv2d( + in_channels=w3.weight.shape[1], + out_channels=w3.weight.shape[0], + kernel_size=1, + padding=0, + bias=False, + ) + + self.w1_conv.weight = torch.nn.Parameter(w1.weight.reshape(*w1.weight.shape, 1, 1)) + self.w2_conv.weight = torch.nn.Parameter(w2.weight.reshape(*w2.weight.shape, 1, 1)) + self.w3_conv.weight = torch.nn.Parameter(w3.weight.reshape(*w3.weight.shape, 1, 1)) + + + def forward(self, x): + rank = x.dim() + x = x.unsqueeze(-1) if rank == 3 else x.reshape(1, *x.shape, 1) + x = torch.transpose(x, 1, 2) + res = self.w2_conv(F.silu(self.w1_conv(x)) * self.w3_conv(x)) + res = torch.transpose(res, 1, 2) + res = res.squeeze(-1) if rank == 3 else res.reshape(*res.shape[1:3]) + return res + +def replace_feedforward_to_conv2d(module: torch.nn.Module): + for name, child in module.named_children(): + if isinstance(child, FeedForward): + setattr(module, name, FeedForwardConv2D(child.w1, child.w2, child.w3)) + else: + replace_feedforward_to_conv2d(child) + return module diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index bd12c374b51..67e7553f813 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -219,9 +219,7 @@ def pt2e_calibrate( from executorch.examples.models.llama.eval_llama_lib import ( GraphModuleEvalWrapper, ) - from executorch.examples.models.llama.evaluate import ( # pyre-ignore[21] - evaluate_model, - ) + from lm_eval.evaluator import simple_evaluate # pyre-ignore[21] except ImportError: raise ImportError( "Please install the llm eval dependency via examples/models/llama/install_requirements.sh" @@ -239,7 +237,7 @@ def calibrate_template( with torch.no_grad(): while token_list[-1] != tokenizer.eos_id and pos < max_len: logits = module( - torch.full((1, 1), token_list[pos]), + torch.full((1, 1), token_list[pos], dtype=torch.int32), torch.tensor((pos,)), ) pos += 1 @@ -250,6 +248,7 @@ def calibrate_template( ) else: 
token_list.append(torch.argmax(logits[:], dim=-1).item()) + print("Calibration Result: ",tokenizer.decode(token_list)) calibrate_template( module=prepared_module, @@ -266,11 +265,14 @@ def calibrate_template( generate_full_logits=self.generate_full_logits, enable_dynamic_shape=self.enable_dynamic_shape, ) - eval_results = evaluate_model( - eval_wrapper, - calibration_tasks, - calibration_limit, - ) + + # Evaluate the model + with torch.no_grad(): + eval_results = simple_evaluate( + model=eval_wrapper, + tasks=calibration_tasks, + limit=calibration_limit, + ) for task, res in eval_results["results"].items(): print(f"{task}: {res}") diff --git a/extension/llm/runner/text_prefiller.cpp b/extension/llm/runner/text_prefiller.cpp index 705583d638b..2e0368411be 100644 --- a/extension/llm/runner/text_prefiller.cpp +++ b/extension/llm/runner/text_prefiller.cpp @@ -41,7 +41,7 @@ ::executorch::runtime::Result TextPrefiller::prefill( auto tokens = from_blob( prompt_tokens.data(), {1, num_prompt_tokens}, - exec_aten::ScalarType::Long); + exec_aten::ScalarType::Int); auto start_pos_tensor = from_blob(&start_pos, {1}, exec_aten::ScalarType::Long); @@ -60,7 +60,7 @@ ::executorch::runtime::Result TextPrefiller::prefill( cur_token = prompt_tokens[0]; // initialize tensor wrappers - auto tokens = from_blob(&cur_token, {1, 1}, exec_aten::ScalarType::Long); + auto tokens = from_blob(&cur_token, {1, 1}, exec_aten::ScalarType::Int); auto start_pos_tensor = from_blob(&start_pos, {1}, exec_aten::ScalarType::Long); diff --git a/extension/llm/runner/text_token_generator.h b/extension/llm/runner/text_token_generator.h index 62b924a57d8..ce453882d16 100644 --- a/extension/llm/runner/text_token_generator.h +++ b/extension/llm/runner/text_token_generator.h @@ -71,7 +71,7 @@ class ET_EXPERIMENTAL TextTokenGenerator { // initialize tensor wrappers auto tokens_managed = from_blob( - token_data.data(), token_shape, executorch::aten::ScalarType::Long); + token_data.data(), token_shape, executorch::aten::ScalarType::Int); auto start_pos_managed = from_blob(&pos, {1}, executorch::aten::ScalarType::Long); diff --git a/runtime/executor/method.cpp b/runtime/executor/method.cpp index a05d789a808..e3d63352d46 100644 --- a/runtime/executor/method.cpp +++ b/runtime/executor/method.cpp @@ -26,7 +26,8 @@ #include #include #include - +#include +#include namespace executorch { namespace runtime { @@ -1004,6 +1005,7 @@ ET_NODISCARD Error Method::get_inputs(EValue* input_evalues, size_t length) { } Error Method::execute_instruction() { + auto begin = std::chrono::high_resolution_clock::now(); auto& chain = chains_[step_state_.chain_idx]; auto instructions = chain.s_chain_->instructions(); @@ -1030,6 +1032,9 @@ Error Method::execute_instruction() { chain.kernels_[step_state_.instr_idx](context, args.data()); // We reset the temp_allocator after the switch statement err = context.failure_state(); + auto op_index = instruction->instr_args_as_KernelCall()->op_index(); + auto op = serialization_plan_->operators()->Get(op_index); + std::cout <<"run op"<name()->c_str()<(end - + begin); + std::cout << "instruction->instr_args_type()" << static_cast(instruction->instr_args_type()) << std::endl; + std::cout<< "delegates_[delegate_idx].Execute Time:" <
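Note on the mutable-buffer (KV cache) flow: it hinges on a naming convention rather than extra metadata. `get_tensor_name` embeds `mutbuf_{n}` into both the input and the output tensor name of a mutated buffer, `QnnManager::AllocateTensor` parses that id back out so both wrappers point at one allocation, and `QnnExecuTorchBackend::execute` skips those tensors when walking the method's `args`. The Python sketch below is illustrative only; the real logic is the C++ in this diff, and the tensor names here are invented for the example.

```python
# Illustrative sketch of the `mutbuf_{id}` pairing convention (names are made up).
import re
from typing import Dict, List, Optional, Tuple


def extract_mutable_buffer_number(name: str) -> Optional[int]:
    """Return the integer after 'mutbuf_' in a tensor name, or None if absent."""
    match = re.search(r"mutbuf_(\d+)", name)
    return int(match.group(1)) if match else None


def pair_mutable_buffers(
    input_names: List[str], output_names: List[str]
) -> Dict[int, Tuple[str, str]]:
    """Pair input/output tensors sharing a mutable-buffer id, mimicking how the
    delegate points both views of a KV-cache buffer at one backing allocation."""
    inputs = {
        extract_mutable_buffer_number(n): n
        for n in input_names
        if extract_mutable_buffer_number(n) is not None
    }
    pairs = {}
    for out_name in output_names:
        buf_id = extract_mutable_buffer_number(out_name)
        if buf_id is not None and buf_id in inputs:
            pairs[buf_id] = (inputs[buf_id], out_name)
    return pairs


ins = ["input_0_tokens_0", "input_1_mutbuf_0_past_k_caches_0_0"]
outs = ["output_logits_0", "output_mutbuf_0_index_put_default_0"]
print(pair_mutable_buffers(ins, outs))  # {0: ('input_1_mutbuf_0_...', 'output_mutbuf_0_...')}
```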
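The `tag_quant_io` helper added in `utils.py` simply walks the graph and stamps a dtype onto matching nodes' metadata so qnn_preprocess leaves those I/O quantized instead of inserting q/dq around them (uint8 for the KV cache, uint16 at sharding boundaries). A minimal standalone sketch is below; the `"qcom_quantized_io"` key string and the toy predicate are assumptions for illustration, the real constant is `QCOM_QUANTIZED_IO` from `backends/qualcomm/utils/constants.py`.

```python
# Hedged sketch of the tag_quant_io flow on a toy exported graph.
import torch


class Toy(torch.nn.Module):
    def forward(self, x, past_kv):
        return x + past_kv, past_kv * 2


ep = torch.export.export(Toy(), (torch.randn(4), torch.randn(4)))
gm = ep.graph_module


def quant_io_dtype(node):
    # Toy predicate: pretend any placeholder named like a KV cache is uint8 I/O.
    if node.op == "placeholder" and "past_kv" in node.name:
        return torch.uint8
    return None


for node in gm.graph.nodes:
    if dtype := quant_io_dtype(node):
        node.meta["qcom_quantized_io"] = dtype  # assumed key string, for illustration

print([(n.name, n.meta.get("qcom_quantized_io")) for n in gm.graph.nodes if n.op == "placeholder"])
```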
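`replace_attention_to_attention_sha` rewrites the fused multi-head projections into one small `nn.Linear` per head so the QNN graph sees single-head matmuls; the split is just a row-slice of the fused weight, exactly as in `AttentionSha.__init__`. A minimal sketch of that equivalence with toy sizes (no KV cache or RoPE):

```python
# One (n_heads * head_dim, dim) projection == n_heads separate (head_dim, dim) projections.
import torch
import torch.nn as nn

dim, n_heads = 64, 4
head_dim = dim // n_heads

wq = nn.Linear(dim, n_heads * head_dim, bias=False)  # fused MHA projection
per_head = nn.ModuleList(
    [nn.Linear(dim, head_dim, bias=False) for _ in range(n_heads)]
)
for i in range(n_heads):
    # rows i*head_dim : (i+1)*head_dim of the fused weight belong to head i
    per_head[i].weight.data.copy_(wq.weight[i * head_dim : (i + 1) * head_dim])

x = torch.randn(1, 8, dim)
fused = wq(x).reshape(1, 8, n_heads, head_dim)
split = torch.stack([w(x) for w in per_head], dim=2)
print(torch.allclose(fused, split, atol=1e-6))  # True: same math, per-head graphs
```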
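Both `convert_linear_to_conv2d` and the new `FeedForwardConv2D` rely on a `nn.Linear` being a 1x1 `Conv2d` once the sequence axis is treated as a spatial axis. `nn.Linear.weight` is laid out `(out_features, in_features)`, which is also why the `_Conv2D` fix in `utils.py` now takes `in_channels` from `weight.shape[1]` and `out_channels` from `weight.shape[0]`. A small numerical check of the equivalence with toy shapes:

```python
# Linear -> 1x1 Conv2d rewrite: same weights reshaped to (out, in, 1, 1).
import torch
import torch.nn as nn

lin = nn.Linear(32, 48, bias=False)
conv = nn.Conv2d(in_channels=32, out_channels=48, kernel_size=1, bias=False)
conv.weight = nn.Parameter(lin.weight.reshape(*lin.weight.shape, 1, 1))

x = torch.randn(2, 5, 32)                       # (batch, seq, dim)
y_lin = lin(x)                                  # (batch, seq, out)
y_conv = conv(x.unsqueeze(-1).transpose(1, 2))  # (batch, out, seq, 1)
y_conv = y_conv.transpose(1, 2).squeeze(-1)     # back to (batch, seq, out)
print(torch.allclose(y_lin, y_conv, atol=1e-5))  # True
```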