diff --git a/backends/qualcomm/partition/qnn_partitioner.py b/backends/qualcomm/partition/qnn_partitioner.py index 86028d0d445..73dbede8ff6 100644 --- a/backends/qualcomm/partition/qnn_partitioner.py +++ b/backends/qualcomm/partition/qnn_partitioner.py @@ -141,18 +141,35 @@ def generate_partitions( op_support=self.op_support_checker, ) - def tag_nodes(self, partitions: List[Partition]) -> None: + def tag_nodes( + self, partitions: List[Partition], edge_program: torch.export.ExportedProgram + ) -> None: for partition in partitions: for node in partition.nodes: delegation_tag = f"qnn_{partition.id}" node.meta["delegation_tag"] = delegation_tag self.partition_tags[delegation_tag] = self.delegation_spec + # need to take care of consumed constants + consumed_constants = ( + *edge_program.graph_signature.inputs_to_buffers, + *edge_program.graph_signature.inputs_to_parameters, + ) + for node in edge_program.graph_module.graph.nodes: + # find placeholders as lifted_constants + if node.op != "placeholder" or len(node.users) != 0: + continue + + if node.name in consumed_constants: + # does no harm to merge them into last partition, + # since they will all be removed in following stage + node.meta["delegation_tag"] = delegation_tag + # override def partition(self, edge_program: torch.export.ExportedProgram) -> PartitionResult: partitions = self.generate_partitions(edge_program) if len(partitions) != 0: - self.tag_nodes(partitions) + self.tag_nodes(partitions, edge_program) tag_constant_data(edge_program) for node in edge_program.graph_module.graph.nodes: if hasattr(node, "meta"):