From d8bbe70ef4143b7096193426486a903c8f420117 Mon Sep 17 00:00:00 2001
From: winskuo-quic
Date: Wed, 5 Feb 2025 17:45:53 +0800
Subject: [PATCH 1/2] - Annotate input only for argmin.
 - Update argmin op builder so it outputs int64, aligning with PyTorch.
 - Add a pass to cast argmin output to int32, since most ops in QNN do not
   support int64.

---
 backends/qualcomm/_passes/__init__.py         |   6 +-
 .../{i64_to_i32.py => constant_i64_to_i32.py} |   5 +-
 backends/qualcomm/_passes/layout_transform.py |   1 -
 .../qualcomm/_passes/tensor_i64_to_i32.py     | 128 ++++++++++++++++++
 backends/qualcomm/_passes/utils.py            |   6 +-
 backends/qualcomm/builders/op_argmin.py       |  52 +++++--
 backends/qualcomm/builders/utils.py           |   4 +-
 backends/qualcomm/quantizer/annotators.py     |  17 ++-
 backends/qualcomm/tests/models.py             |  27 ++++
 backends/qualcomm/tests/test_qnn_delegate.py  |  26 +++-
 backends/qualcomm/utils/constants.py          |   1 +
 backends/qualcomm/utils/utils.py              |  11 +-
 12 files changed, 257 insertions(+), 27 deletions(-)
 rename backends/qualcomm/_passes/{i64_to_i32.py => constant_i64_to_i32.py} (95%)
 create mode 100644 backends/qualcomm/_passes/tensor_i64_to_i32.py

diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py
index de4b7ce2cc9..d03ba2f9a25 100644
--- a/backends/qualcomm/_passes/__init__.py
+++ b/backends/qualcomm/_passes/__init__.py
@@ -1,18 +1,19 @@
 from .annotate_and_quant_scalar import AnnotateAndQuantScalar
 from .annotate_decomposed import AnnotateDecomposed
 from .annotate_quant_attrs import AnnotateQuantAttrs
+from .constant_i64_to_i32 import ConstantI64toI32
 from .convert_bmm_to_matmul import ConvertBmmToMatmul
 from .convert_interpolate_with_upsample2d import ConvertInterpolateWithUpsample2D
 from .convert_prelu import ConvertPReLU
 from .convert_to_linear import ConvertToLinear
 from .expand_broadcast_tensor_shape import ExpandBroadcastTensorShape
 from .fold_qdq import FoldQDQ
-from .i64_to_i32 import I64toI32
 from .layout_transform import LayoutTransform
 from .recompose_pixel_unshuffle import RecomposePixelUnshuffle
 from .recompose_rms_norm import RecomposeRmsNorm
 from .remove_redundancy import RemoveRedundancy
 from .replace_index_put_input import ReplaceIndexPutInput
+from .tensor_i64_to_i32 import TensorI64toI32
 
 
 __all__ = [
@@ -25,7 +26,8 @@
     ConvertToLinear,
     ExpandBroadcastTensorShape,
     FoldQDQ,
-    I64toI32,
+    ConstantI64toI32,
+    TensorI64toI32,
     LayoutTransform,
     RecomposePixelUnshuffle,
     RecomposeRmsNorm,
diff --git a/backends/qualcomm/_passes/i64_to_i32.py b/backends/qualcomm/_passes/constant_i64_to_i32.py
similarity index 95%
rename from backends/qualcomm/_passes/i64_to_i32.py
rename to backends/qualcomm/_passes/constant_i64_to_i32.py
index 29c747d1a1a..9b5178b386e 100644
--- a/backends/qualcomm/_passes/i64_to_i32.py
+++ b/backends/qualcomm/_passes/constant_i64_to_i32.py
@@ -12,9 +12,10 @@
 from torch._subclasses.fake_tensor import FakeTensor
 
 
-class I64toI32(ExportPass):
+class ConstantI64toI32(ExportPass):
     """
     Cast unsupported int64 datatype into int32.
+    This will only be applied on constant nodes such as weights.
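+    See the TensorI64toI32 pass for int64 tensors produced at runtime (e.g., argmin output).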
""" def __init__( @@ -22,7 +23,7 @@ def __init__( edge_program: torch.export.ExportedProgram, skip_node: FrozenSet[str] = frozenset(), ): - super(I64toI32, self).__init__() + super(ConstantI64toI32, self).__init__() self.edge_program = edge_program self.skip_node = skip_node # pyre-ignore[4] diff --git a/backends/qualcomm/_passes/layout_transform.py b/backends/qualcomm/_passes/layout_transform.py index 67f33873d44..ccc34d3a528 100644 --- a/backends/qualcomm/_passes/layout_transform.py +++ b/backends/qualcomm/_passes/layout_transform.py @@ -45,7 +45,6 @@ class LayoutTransform(ExportPass): layout_agnostic_ops = { exir_ops.edge.aten.abs.default, exir_ops.edge.aten.add.Tensor, - exir_ops.edge.aten.argmin.default, exir_ops.edge.aten.bmm.default, exir_ops.edge.aten.cat.default, exir_ops.edge.aten.ceil.default, diff --git a/backends/qualcomm/_passes/tensor_i64_to_i32.py b/backends/qualcomm/_passes/tensor_i64_to_i32.py new file mode 100644 index 00000000000..60178cbc1ba --- /dev/null +++ b/backends/qualcomm/_passes/tensor_i64_to_i32.py @@ -0,0 +1,128 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import torch +from executorch.backends.qualcomm.builders.utils import is_graph_output +from executorch.backends.qualcomm.utils.constants import QCOM_ORIG_DTYPE +from executorch.exir import ExirExportedProgram +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.program._program import _get_updated_graph_signature +from torch._subclasses.fake_tensor import FakeTensor + + +class TensorI64toI32(ExportPass): + """ + Insert a cast node to cast dtype from int64 to int32. + This will only be applied on fake tensors. 
+ """ + + cast_ops = { + torch.ops.aten.argmin.default, + } + + def __init__(self, edge_program): + super(TensorI64toI32, self).__init__() + self.edge_program = edge_program + + # pyre-ignore[2] + def _is_tensor_of_dtype(self, node_val, dtype: torch.dtype) -> bool: + return isinstance(node_val, FakeTensor) and node_val.dtype == dtype + + def _cast_to_int32(self, core_ep: ExirExportedProgram): + copy_op = torch.ops.aten._to_copy.default + for n in core_ep.exported_program.graph.nodes: + # Keep track of original output dtype so we ensure the dtype of the graph is consistent with nn.Module + if is_graph_output(n): + if isinstance(n.meta["val"], tuple): + dtype_list = [tensor.dtype for tensor in n.meta["val"]] + n.meta[QCOM_ORIG_DTYPE] = dtype_list + else: + n.meta[QCOM_ORIG_DTYPE] = n.meta["val"].dtype + continue + if n.target in self.cast_ops: + node_val = n.meta["val"] + if self._is_tensor_of_dtype(node_val, torch.int64): + with core_ep.exported_program.graph.inserting_after(n): + users = list(n.users.keys()) + args = (n,) + cast_node = core_ep.exported_program.graph.create_node( + "call_function", + copy_op, + args, + {"dtype": torch.int32}, + ) + cast_node.meta["val"] = node_val.to(torch.int32) + cast_node.args = args + + for user in users: + user.replace_input_with(n, cast_node) + + core_ep.exported_program._graph_signature = _get_updated_graph_signature( + core_ep.exported_program._graph_signature, + core_ep.exported_program.graph_module, + ) + core_ep.exported_program._validate() + + def _preserve_output_dtype( + self, exported_program: torch.export.exported_program.ExportedProgram + ): + graph_module = exported_program.graph_module + copy_op = exir_ops.edge.aten._to_copy.default + for n in graph_module.graph.nodes: + if is_graph_output(n) and QCOM_ORIG_DTYPE in n.meta: + if isinstance(n.meta["val"], tuple): + for i, dtype in enumerate(n.meta[QCOM_ORIG_DTYPE]): + # TODO: Enable this in future to support OP such as topK + if n.meta["val"][i].dtype != dtype: + raise AssertionError( + "Multi output nodes currently don't support casting dtype back." + ) + elif n.meta["val"].dtype != n.meta[QCOM_ORIG_DTYPE]: + if n.meta[QCOM_ORIG_DTYPE] != torch.int64: + logging.warning( + "This pass is intended to maintain output as int64 when nn.Module outputs int64. Other dtype modification is detected. Please ensure this is desired." + ) + with graph_module.graph.inserting_after(n): + orig_dtype = n.meta[QCOM_ORIG_DTYPE] + node_val = n.meta["val"] + args = (n,) + users = list(n.users.keys()) + output_users = [ + user for user in users if user.target == "output" + ] + cast_node = graph_module.graph.create_node( + "call_function", + copy_op, + args, + {"dtype": orig_dtype}, + ) + cast_node.meta["val"] = node_val.to(orig_dtype) + cast_node.args = args + for user in output_users: + user.replace_input_with(n, cast_node) + + def call(self, graph_module: torch.fx.GraphModule): + # Stage 1: _cast_to_int32 + # We add to_copy after the desired operations during this stage because the data type only propagates before to_edge. + # If we don't add to_copy here but do it after to_edge, the next operation after to_copy() will still expect int64 as its output. + # Stage 2: _preserve_output_dtype + # We will tag the output dtype during stage 1, and we will ensure that if user expects int64 as output, + # we need to convert the output back to int64 if it is casted from int64->int32 during stage 1. 
+        if isinstance(self.edge_program, ExirExportedProgram):
+            self._cast_to_int32(self.edge_program)
+            self.edge_program.exported_program.graph_module.recompile()
+        elif isinstance(
+            self.edge_program, torch.export.exported_program.ExportedProgram
+        ):
+            self._preserve_output_dtype(self.edge_program)
+        else:
+            raise AssertionError(
+                "Should be ExirExportedProgram at stage 1 and torch.export.exported_program.ExportedProgram at stage 2"
+            )
+        return PassResult(graph_module, True)
diff --git a/backends/qualcomm/_passes/utils.py b/backends/qualcomm/_passes/utils.py
index a606a21c625..e9977280922 100755
--- a/backends/qualcomm/_passes/utils.py
+++ b/backends/qualcomm/_passes/utils.py
@@ -60,18 +60,19 @@ def get_passes_dependency_for_capture_program():
         AnnotateAndQuantScalar,
         AnnotateDecomposed,
         AnnotateQuantAttrs,
+        ConstantI64toI32,
         ConvertBmmToMatmul,
         ConvertInterpolateWithUpsample2D,
         ConvertPReLU,
         ConvertToLinear,
         ExpandBroadcastTensorShape,
         FoldQDQ,
-        I64toI32,
         LayoutTransform,
         RecomposePixelUnshuffle,
         RecomposeRmsNorm,
         RemoveRedundancy,
         ReplaceIndexPutInput,
+        TensorI64toI32,
     )
 
     return {
@@ -81,7 +82,8 @@ def get_passes_dependency_for_capture_program():
         ConvertPReLU: [RemoveRedundancy],
         ConvertBmmToMatmul: [ConvertToLinear],
         ConvertInterpolateWithUpsample2D: [RemoveRedundancy],
-        I64toI32: [RemoveRedundancy],
+        ConstantI64toI32: [RemoveRedundancy],
+        TensorI64toI32: [RemoveRedundancy],
         AnnotateQuantAttrs: [
             RecomposePixelUnshuffle,
             RecomposeRmsNorm,
diff --git a/backends/qualcomm/builders/op_argmin.py b/backends/qualcomm/builders/op_argmin.py
index c09cb04f169..5630b02a5cc 100644
--- a/backends/qualcomm/builders/op_argmin.py
+++ b/backends/qualcomm/builders/op_argmin.py
@@ -10,8 +10,8 @@
 import torch
 from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA
 
-from .node_visitor import NodeVisitor, register_node_visitor
-from .qnn_constants import OpArgmin, QNN_OP_PACKAGE_NAME_QTI_AISW
+from .node_visitor import NodeVisitor, QNN_TENSOR_TYPE_MAP, register_node_visitor
+from .qnn_constants import OpArgmin, OpCast, QNN_OP_PACKAGE_NAME_QTI_AISW
 
 
 @register_node_visitor
@@ -26,8 +26,10 @@ def define_node(
         node: torch.fx.Node,
         nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
     ) -> PyQnnWrapper.PyQnnOpWrapper:
+        op_wrapper_list = []
         input_node = node.args[0]
         input_tensor = self.get_tensor(input_node, node)
+        output_tensor = self.get_tensor(node, node)
         argmin_inp_tensor_wrapper = self.define_tensor(
             input_node,
             node,
@@ -37,17 +39,25 @@ def define_node(
         )
         argmin_input_tensors = [argmin_inp_tensor_wrapper]
 
-        output_tensor = self.get_tensor(node, node).to(torch.int32)
         # arg output is index, do not quantize it.
         node.meta.pop("quant_attrs", None)
-        output_tensor_wrapper = self.define_tensor(
-            node,
-            node,
-            output_tensor,
-            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
-            nodes_to_wrappers,
+        input_quant_encoding, input_quant_configs = self.get_quant_encoding_conf(
+            input_node, node
         )
-        argmin_output_tensors = [output_tensor_wrapper]
+
+        argmin_intermediate_tensor_wrapper = self.define_custom_tensor_wrapper(
+            node_name=node.name + "_cast",
+            tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            dtype=QNN_TENSOR_TYPE_MAP[torch.int32],
+            quant_encoding=input_quant_encoding,
+            quant_configs=input_quant_configs,
+            dims=output_tensor.size(),
+            tensor=output_tensor,
+            is_fake_tensor=True,
+            nodes_to_wrappers=nodes_to_wrappers,
+        )
+
+        argmin_output_tensors = [argmin_intermediate_tensor_wrapper]
 
         dim = cast(int, node.args[1])
         if dim < 0:
@@ -77,4 +87,24 @@ def define_node(
             {QCOM_DATA: keep_dims},
         )
 
-        return argmin_op
+        op_wrapper_list.append(argmin_op)
+
+        cast_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name + "_cast",
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpCast.op_name,
+        )
+
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        cast_op.AddInputTensors([argmin_intermediate_tensor_wrapper])
+        cast_op.AddOutputTensors([output_tensor_wrapper])
+        op_wrapper_list.append(cast_op)
+
+        return op_wrapper_list
diff --git a/backends/qualcomm/builders/utils.py b/backends/qualcomm/builders/utils.py
index ede32a5e659..c82ebaf1bb3 100755
--- a/backends/qualcomm/builders/utils.py
+++ b/backends/qualcomm/builders/utils.py
@@ -75,14 +75,14 @@ def is_graph_input(
     return tensor.op == "placeholder" and not is_parameter(tensor, edge_program)
 
 
-def is_graph_output(tensor: torch.fx.Node) -> bool:
+def is_graph_output(node: torch.fx.Node) -> bool:
     """
     Check if the given tensor is used as a graph output
 
     Args:
         tensor: EdgeIR Tensor that is being checked for graph input
     """
-    for user in tensor.users.keys():
+    for user in node.users.keys():
         # getitem node is skiped, check the op_skip_ops.py
         if user.op == "output" or (
             user.target.__name__ == "getitem" and is_graph_output(user)
diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py
index c9f28ae760b..6b73b14c6ef 100644
--- a/backends/qualcomm/quantizer/annotators.py
+++ b/backends/qualcomm/quantizer/annotators.py
@@ -110,6 +110,21 @@ def annotate_in_out_obs_sharing_op(
     )
 
 
+def annotate_single_in(node: Node, quantization_config: QuantizationConfig) -> None:
+    if _is_annotated([node]):
+        return
+
+    input_qspec_map = {}
+    input_act = node.args[0]
+    assert isinstance(input_act, Node)
+    input_qspec_map[input_act] = quantization_config.input_activation
+
+    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+        input_qspec_map=input_qspec_map,
+        _annotated=True,
+    )
+
+
 def annotate_single_in_single_out(
     node: Node, quantization_config: QuantizationConfig
 ) -> None:
@@ -171,7 +186,7 @@ def annotate_add(node: Node, quantization_config: QuantizationConfig) -> None:
 def annotate_argmin(node: Node, quantization_config: QuantizationConfig) -> None:
     if _is_annotated([node]):
         return
-    annotate_single_in_single_out(node, quantization_config)
+    annotate_single_in(node, quantization_config)
 
 
 @register_annotator([torch.ops.aten.sub, torch.ops.aten.sub.Tensor])
diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py
index 6e758a5c45f..337925a581e 100644
--- a/backends/qualcomm/tests/models.py
+++ b/backends/qualcomm/tests/models.py
@@ -66,6 +66,33 @@ def forward(self, y):
         )
 
 
+class Argmin(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        x = torch.argmin(x, dim=0, keepdim=True)
+        return x
+
+
+class ArgminViewSqueezeConv2D(torch.nn.Module):
+    def __init__(self):
+        # This model is mainly used to test the TensorI64toI32 pass
+        super().__init__()
+        self.conv = torch.nn.Conv2d(
+            in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1
+        )
+
+    def forward(self, x, y):
+        argmin_out = torch.argmin(x, dim=0, keepdim=True)
+        index_out = y[argmin_out]
+        conv_out = self.conv(index_out)
+
+        view_out = argmin_out.view(-1)
+        squeeze_out = view_out.squeeze(-1)
+        return squeeze_out, conv_out
+
+
 class AvgPoolModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index 9ceaf60c93d..4b6767c263c 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -113,7 +113,7 @@ def test_qnn_backend_arange(self):
         self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_argmin(self):
-        module = Conv2dArgmin()  # noqa: F405
+        module = Argmin()  # noqa: F405
         sample_input = (torch.randn(16, 3, 4, 4),)
         self.lower_module_and_test_output(module, sample_input)
 
@@ -705,6 +705,11 @@ def setUp(self):
             shared_buffer=TestQNN.shared_buffer,
         )
 
+    def test_qnn_backend_argmin_view_squeeze_conv2d(self):
+        module = ArgminViewSqueezeConv2D()  # noqa: F405
+        sample_input = (torch.randn(32), torch.randn(32, 3, 32, 32))
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_chunk_add(self):
         module = ChunkAdd()  # noqa: F405
         torch.manual_seed(8)
@@ -716,6 +721,11 @@ def test_qnn_backend_conv1d_relu_log_softmax(self):
         sample_input = (torch.rand(1, 2, 28),)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_conv2d_argmin(self):
+        module = Conv2dArgmin()  # noqa: F405
+        sample_input = (torch.randn(16, 3, 4, 4),)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_conv2d_avg_pool2d(self):
         module = Conv2dAvgPool2d()  # noqa: F405
         sample_input = (torch.randn(16, 3, 16, 16),)
@@ -945,7 +955,7 @@ def test_qnn_backend_arange(self):
         self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_argmin(self):
-        module = Conv2dArgmin()  # noqa: F405
+        module = Argmin()  # noqa: F405
         sample_input = (torch.randn(16, 3, 4, 4),)
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
@@ -1616,6 +1626,12 @@ def setUp(self):
             shared_buffer=TestQNN.shared_buffer,
         )
 
+    def test_qnn_backend_argmin_view_squeeze_conv2d(self):
+        module = ArgminViewSqueezeConv2D()  # noqa: F405
+        sample_input = (torch.randn(32), torch.randn(32, 3, 32, 32))
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_chunk_add(self):
         module = ChunkAdd()  # noqa: F405
         torch.manual_seed(8)
@@ -1629,6 +1645,12 @@ def test_qnn_backend_conv1d_relu_log_softmax(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_conv2d_argmin(self):
+        module = Conv2dArgmin()  # noqa: F405
+        sample_input = (torch.randn(16, 3, 4, 4),)
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_conv2d_avg_pool2d(self):
         module = Conv2dAvgPool2d()  # noqa: F405
         sample_input = (torch.randn(16, 3, 16, 16),)
diff --git a/backends/qualcomm/utils/constants.py b/backends/qualcomm/utils/constants.py
index 4f73d331ad5..c31e8d2f35d 100644
--- a/backends/qualcomm/utils/constants.py
+++ b/backends/qualcomm/utils/constants.py
@@ -16,6 +16,7 @@
 QCOM_INSERTED_PERMUTE = "qnn_permute"
 QCOM_LAYOUT_CHANGE = "layout_change"
 QCOM_OFFSET = "offset"
+QCOM_ORIG_DTYPE = "orig_dtype"
 QCOM_QUANTIZED_IO = "q_tensor_io"
 QCOM_QUANT_ATTRS = "quant_attrs"
 QCOM_QUANT_MIN = "quant_min"
diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py
index f94e22c5306..f56be7782d7 100644
--- a/backends/qualcomm/utils/utils.py
+++ b/backends/qualcomm/utils/utils.py
@@ -21,6 +21,7 @@
 )
 from executorch.backends.qualcomm._passes.annotate_decomposed import AnnotateDecomposed
 from executorch.backends.qualcomm._passes.annotate_quant_attrs import AnnotateQuantAttrs
+from executorch.backends.qualcomm._passes.constant_i64_to_i32 import ConstantI64toI32
 from executorch.backends.qualcomm._passes.convert_binary_op_with_scalar import (
     ConvertBinaryOpsWithScalar,
 )
@@ -36,7 +37,6 @@
     ExpandBroadcastTensorShape,
 )
 from executorch.backends.qualcomm._passes.fold_qdq import FoldQDQ
-from executorch.backends.qualcomm._passes.i64_to_i32 import I64toI32
 from executorch.backends.qualcomm._passes.layout_transform import LayoutTransform
 from executorch.backends.qualcomm._passes.recompose_pixel_unshuffle import (
     RecomposePixelUnshuffle,
 )
@@ -46,6 +46,7 @@
 from executorch.backends.qualcomm._passes.replace_index_put_input import (
     ReplaceIndexPutInput,
 )
+from executorch.backends.qualcomm._passes.tensor_i64_to_i32 import TensorI64toI32
 from executorch.backends.qualcomm._passes.utils import (
     get_passes_dependency_for_capture_program,
 )
@@ -327,7 +328,8 @@ def get_capture_program_passes():
         (ConvertPReLU, True),
         (ConvertBmmToMatmul, True),
         (ConvertInterpolateWithUpsample2D, True),
-        (I64toI32, True),
+        (ConstantI64toI32, True),
+        (TensorI64toI32, True),
         (AnnotateQuantAttrs, True),
         (AnnotateAndQuantScalar, True),
         (AnnotateDecomposed, True),
@@ -411,9 +413,10 @@ def capture_program(
     # TODO: Should modify the scalar op in the op builder instead of
    # using transformation
     core_ep = ExirExportedProgram(decomposed_ep, False)
-    core_ep.transform(ConvertBinaryOpsWithScalar())
+    core_ep.transform(
+        TensorI64toI32(edge_program=core_ep), ConvertBinaryOpsWithScalar()
+    )
     edge_ep = core_ep.to_edge(qnn_edge_config())
-
     _transform(edge_ep.exported_program, passes_job)
     return edge_ep

From 66929fd8db8282f6cc886ea3d61ef1aa6abf86cc Mon Sep 17 00:00:00 2001
From: winskuo-quic
Date: Tue, 11 Feb 2025 11:19:24 +0800
Subject: [PATCH 2/2] Update import path for the renamed ConstantI64toI32 pass

---
 examples/qualcomm/oss_scripts/llama/llama.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py
index e575a3f5c48..33eb96f6af1 100755
--- a/examples/qualcomm/oss_scripts/llama/llama.py
+++ b/examples/qualcomm/oss_scripts/llama/llama.py
@@ -19,7 +19,7 @@
 from multiprocessing.connection import Client
 
 import torch
-from executorch.backends.qualcomm._passes.i64_to_i32 import I64toI32
+from executorch.backends.qualcomm._passes.constant_i64_to_i32 import ConstantI64toI32
 from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner
 
 
@@ -553,9 +553,9 @@ def compile(args, pte_filename, tokenizer):
             llama_instance_list[i] = get_quant_embedding_transform(args)(
                 llama_instance_list[i]
             )
-            passes_job[I64toI32][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY]["skip_node"] = {
-                "tokens"
-            }
+            passes_job[ConstantI64toI32][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][
+                "skip_node"
+            ] = {"tokens"}
         llama_instance_list[i] = convert_linear_to_conv2d(llama_instance_list[i])
         llama_instance_list[i] = SingleLlama(
             llama_instance_list[i].eval(), pte_filename