Skip to content

Commit 35334a0

Browse files
committed
Make consumed lifted_constants as delegated nodes
1 parent 6cb5726 commit 35334a0

File tree

1 file changed

+19
-2
lines changed

1 file changed

+19
-2
lines changed

backends/qualcomm/partition/qnn_partitioner.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -141,18 +141,35 @@ def generate_partitions(
141141
op_support=self.op_support_checker,
142142
)
143143

144-
def tag_nodes(self, partitions: List[Partition]) -> None:
144+
def tag_nodes(
145+
self, partitions: List[Partition], edge_program: torch.export.ExportedProgram
146+
) -> None:
145147
for partition in partitions:
146148
for node in partition.nodes:
147149
delegation_tag = f"qnn_{partition.id}"
148150
node.meta["delegation_tag"] = delegation_tag
149151
self.partition_tags[delegation_tag] = self.delegation_spec
150152

153+
# need to take care of consumed constants
154+
consumed_constants = (
155+
*edge_program.graph_signature.inputs_to_buffers,
156+
*edge_program.graph_signature.inputs_to_parameters,
157+
)
158+
for node in edge_program.graph_module.graph.nodes:
159+
# find placeholders as lifted_constants
160+
if node.op != "placeholder" or len(node.users) != 0:
161+
continue
162+
163+
if node.name in consumed_constants:
164+
# does no harm to merge them into last partition,
165+
# since they will all be removed in following stage
166+
node.meta["delegation_tag"] = delegation_tag
167+
151168
# override
152169
def partition(self, edge_program: torch.export.ExportedProgram) -> PartitionResult:
153170
partitions = self.generate_partitions(edge_program)
154171
if len(partitions) != 0:
155-
self.tag_nodes(partitions)
172+
self.tag_nodes(partitions, edge_program)
156173
tag_constant_data(edge_program)
157174
for node in edge_program.graph_module.graph.nodes:
158175
if hasattr(node, "meta"):

0 commit comments

Comments
 (0)