@@ -141,18 +141,35 @@ def generate_partitions(
141
141
op_support = self .op_support_checker ,
142
142
)
143
143
144
- def tag_nodes (self , partitions : List [Partition ]) -> None :
144
+ def tag_nodes (
145
+ self , partitions : List [Partition ], edge_program : torch .export .ExportedProgram
146
+ ) -> None :
145
147
for partition in partitions :
146
148
for node in partition .nodes :
147
149
delegation_tag = f"qnn_{ partition .id } "
148
150
node .meta ["delegation_tag" ] = delegation_tag
149
151
self .partition_tags [delegation_tag ] = self .delegation_spec
150
152
153
+ # need to take care of consumed constants
154
+ consumed_constants = (
155
+ * edge_program .graph_signature .inputs_to_buffers ,
156
+ * edge_program .graph_signature .inputs_to_parameters ,
157
+ )
158
+ for node in edge_program .graph_module .graph .nodes :
159
+ # find placeholders as lifted_constants
160
+ if node .op != "placeholder" or len (node .users ) != 0 :
161
+ continue
162
+
163
+ if node .name in consumed_constants :
164
+ # does no harm to merge them into last partition,
165
+ # since they will all be removed in following stage
166
+ node .meta ["delegation_tag" ] = delegation_tag
167
+
151
168
# override
152
169
def partition (self , edge_program : torch .export .ExportedProgram ) -> PartitionResult :
153
170
partitions = self .generate_partitions (edge_program )
154
171
if len (partitions ) != 0 :
155
- self .tag_nodes (partitions )
172
+ self .tag_nodes (partitions , edge_program )
156
173
tag_constant_data (edge_program )
157
174
for node in edge_program .graph_module .graph .nodes :
158
175
if hasattr (node , "meta" ):
0 commit comments