From 35334a04aed8806f75a0328712fb02666919af8c Mon Sep 17 00:00:00 2001 From: haowhsu-quic Date: Mon, 19 Aug 2024 18:18:22 +0800 Subject: [PATCH] Make consumed lifted_constants as delegated nodes --- .../qualcomm/partition/qnn_partitioner.py | 21 +++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/backends/qualcomm/partition/qnn_partitioner.py b/backends/qualcomm/partition/qnn_partitioner.py index 86028d0d445..73dbede8ff6 100644 --- a/backends/qualcomm/partition/qnn_partitioner.py +++ b/backends/qualcomm/partition/qnn_partitioner.py @@ -141,18 +141,35 @@ def generate_partitions( op_support=self.op_support_checker, ) - def tag_nodes(self, partitions: List[Partition]) -> None: + def tag_nodes( + self, partitions: List[Partition], edge_program: torch.export.ExportedProgram + ) -> None: for partition in partitions: for node in partition.nodes: delegation_tag = f"qnn_{partition.id}" node.meta["delegation_tag"] = delegation_tag self.partition_tags[delegation_tag] = self.delegation_spec + # need to take care of consumed constants + consumed_constants = ( + *edge_program.graph_signature.inputs_to_buffers, + *edge_program.graph_signature.inputs_to_parameters, + ) + for node in edge_program.graph_module.graph.nodes: + # find placeholders as lifted_constants + if node.op != "placeholder" or len(node.users) != 0: + continue + + if node.name in consumed_constants: + # does no harm to merge them into last partition, + # since they will all be removed in following stage + node.meta["delegation_tag"] = delegation_tag + # override def partition(self, edge_program: torch.export.ExportedProgram) -> PartitionResult: partitions = self.generate_partitions(edge_program) if len(partitions) != 0: - self.tag_nodes(partitions) + self.tag_nodes(partitions, edge_program) tag_constant_data(edge_program) for node in edge_program.graph_module.graph.nodes: if hasattr(node, "meta"):