6
6
7
7
import copy
8
8
import logging
9
- from contextlib import contextmanager
9
+ from contextlib import contextmanager , nullcontext
10
10
from functools import singledispatch
11
11
from typing import Generator , List
12
12
25
25
26
26
from executorch .exir .graph_module import get_control_flow_submodules
27
27
from executorch .exir .lowered_backend_module import (
28
- _get_new_signature ,
28
+ _unsafe_adjust_original_program ,
29
29
create_exported_program_from_submodule ,
30
30
create_submodule_from_nodes ,
31
31
LoweredBackendModule ,
32
32
)
33
- from executorch .exir .pass_base import ExportPass
34
33
from executorch .exir .program ._fake_program import (
35
34
get_fake_program ,
36
35
update_to_real_program ,
@@ -193,6 +192,7 @@ def _partition_and_lower_one_graph_module(
193
192
tagged_graph_module : torch .fx .GraphModule ,
194
193
partition_result : PartitionResult ,
195
194
owning_program : ExportedProgram ,
195
+ is_submodule : bool ,
196
196
) -> torch .fx .GraphModule :
197
197
"""
198
198
Partitioned and lowered the graph module based on the partition tag, this is to handle one graph module.
@@ -210,21 +210,40 @@ def _partition_and_lower_one_graph_module(
210
210
211
211
logging .debug (f"For tag { tag } , found nodes { node_list } " )
212
212
# Tag the nodes that are params as buffers, so we can order the submodule as (Parms + Buffers) (User Inputs)
213
- submodule , call_module_node = create_submodule_from_nodes (
214
- tagged_graph_module , node_list , tag
213
+
214
+ replace_ctx = (
215
+ tagged_graph_module ._set_replace_hook (
216
+ owning_program .graph_signature .get_replace_hook ()
217
+ )
218
+ if not is_submodule
219
+ else nullcontext ()
215
220
)
221
+ with replace_ctx :
222
+ submodule , call_module_node = create_submodule_from_nodes (
223
+ tagged_graph_module , node_list , tag
224
+ )
225
+
216
226
tagged_graph_module_output_node = [
217
227
node for node in tagged_graph_module .graph .nodes if node .op == "output"
218
- ]
228
+ ][ 0 ]
219
229
submodule_output_node = [
220
230
node for node in submodule .graph .nodes if node .op == "output"
221
- ]
222
- # Copy the output node meta from the original output node, because create_submodule_from_nodes doesn't cover the meta field
223
- submodule_output_node [0 ].meta = tagged_graph_module_output_node [0 ].meta
231
+ ][0 ]
232
+ # Copy the output node meta from the original output node, because
233
+ # create_submodule_from_nodes doesn't cover the meta field
234
+ submodule_output_node .meta = tagged_graph_module_output_node .meta
224
235
logging .debug (f"Partitioned graph module: { tagged_graph_module } " )
225
236
226
- submodule_program = create_exported_program_from_submodule (
227
- submodule , owning_program , tag
237
+ (
238
+ submodule_program ,
239
+ toplevel_input_specs_to_delete ,
240
+ toplevel_output_specs_to_delete ,
241
+ ) = create_exported_program_from_submodule (
242
+ submodule ,
243
+ owning_program ,
244
+ tag ,
245
+ call_module_node ,
246
+ is_submodule ,
228
247
)
229
248
230
249
lowered_submodule = to_backend (
@@ -257,64 +276,48 @@ def _partition_and_lower_one_graph_module(
257
276
call_delegate_node .meta ["debug_handle" ] = len (
258
277
tagged_graph_module .graph .nodes
259
278
)
279
+ call_delegate_node .meta ["val" ] = submodule_output_node .meta ["val" ]
260
280
call_module_node .replace_all_uses_with (call_delegate_node )
261
281
tagged_graph_module .graph .erase_node (call_module_node )
262
282
263
- # Delete all parameters/buffers consumed by the created exported program
264
- toplevel_signature = owning_program .graph_signature
265
- for node in tagged_graph_module .graph .nodes :
266
- # Find placeholders consumed by the delegate
267
- if node .op != "placeholder" or len (node .users ) != 0 :
268
- continue
269
-
270
- if node .name in toplevel_signature .inputs_to_buffers :
271
- # Delete the consumed buffers
272
- buffer_name = toplevel_signature .inputs_to_buffers .get (node .name )
273
- if buffer_name in owning_program .state_dict :
274
- owning_program .state_dict .pop (buffer_name )
275
- else :
276
- owning_program .constants .pop (buffer_name )
277
- tagged_graph_module .graph .erase_node (node )
278
- elif node .name in toplevel_signature .inputs_to_parameters :
279
- # Delete the consumed parameters
280
- param_name = toplevel_signature .inputs_to_parameters .get (node .name )
281
- owning_program .state_dict .pop (param_name )
282
- tagged_graph_module .graph .erase_node (node )
283
-
284
- tagged_graph_module .recompile ()
283
+ if is_submodule :
284
+ assert len (toplevel_input_specs_to_delete ) == 0
285
+ assert len (toplevel_output_specs_to_delete ) == 0
286
+ elif (
287
+ len (toplevel_input_specs_to_delete ) > 0
288
+ or len (toplevel_output_specs_to_delete ) > 0
289
+ ):
290
+ _unsafe_adjust_original_program (
291
+ owning_program ,
292
+ call_delegate_node ,
293
+ toplevel_input_specs_to_delete ,
294
+ toplevel_output_specs_to_delete ,
295
+ )
296
+
285
297
return tagged_graph_module
286
298
287
299
288
300
def _partition_and_lower (
289
301
tagged_graph_module : torch .fx .GraphModule ,
290
302
partition_result : PartitionResult ,
291
303
owning_program : ExportedProgram ,
304
+ is_submodule : bool = False ,
292
305
) -> torch .fx .GraphModule :
293
306
"""
294
307
Partitions the graph module into submodules based on tags, and then lowered the nodes with the same tag as one lowered module, including the submodule from control flow
295
308
"""
296
309
297
310
partitioned_module = _partition_and_lower_one_graph_module (
298
- tagged_graph_module , partition_result , owning_program
311
+ tagged_graph_module , partition_result , owning_program , is_submodule
299
312
)
300
313
301
314
# Recursively partition and lower for submodules
302
315
for name , submod , _node in get_control_flow_submodules (partitioned_module ):
303
316
partitioned_submodule = _partition_and_lower (
304
- submod , partition_result , owning_program
317
+ submod , partition_result , owning_program , is_submodule = True
305
318
)
306
319
tagged_graph_module .add_module (name , partitioned_submodule )
307
320
308
- # Run the export pass over the graph module so that the call delegate
309
- # nodes will match Edge dialect
310
- # TODO(angelayi): ExportPass will rerun the graph, however all we need
311
- # here is to add metadata to the call delegate nodes to preserve Edge
312
- # dialect. There's work going on to generate a random tensor from a
313
- # fake tensor and possibly it can help to address the issue.
314
- res = ExportPass ()(tagged_graph_module )
315
- assert res is not None
316
- tagged_graph_module = res .graph_module
317
-
318
321
return tagged_graph_module
319
322
320
323
@@ -349,6 +352,8 @@ def to_backend(
349
352
Returns:
350
353
ExportedProgram: The input program, with some portions targeted for delegation.
351
354
"""
355
+ edge_program ._validate ()
356
+
352
357
# Use fake program, with FakeTensors in the state dict, to avoid copying large constant values.
353
358
# Fall back to deepcopy if no fake mode is found. TODO(T182910699): Remove this fallback.
354
359
try :
@@ -377,26 +382,22 @@ def to_backend(
377
382
update_to_real_program (tagged_exported_program , edge_program )
378
383
379
384
for tag , _ in partitioner_result .partition_tags .items ():
380
- _maybe_duplicate_constant_nodes (tagged_exported_program , tag , edge_program )
385
+ _maybe_duplicate_constant_nodes (tagged_exported_program , tag )
381
386
382
387
tagged_graph_module = _partition_and_lower (
383
- tagged_exported_program .graph_module , partitioner_result , edge_program
388
+ tagged_exported_program .graph_module ,
389
+ partitioner_result ,
390
+ tagged_exported_program ,
384
391
)
385
392
386
- # TODO(angelayi): Update this signature in a less manual way (maybe through
387
- # retracing)
388
- new_signature , new_state_dict , new_constants = _get_new_signature (
389
- edge_program ,
390
- tagged_graph_module ,
391
- )
392
393
return ExportedProgram (
393
394
root = tagged_graph_module ,
394
395
graph = tagged_graph_module .graph ,
395
- graph_signature = new_signature ,
396
- state_dict = new_state_dict ,
397
- range_constraints = copy .deepcopy (edge_program .range_constraints ),
398
- module_call_graph = copy .deepcopy (edge_program .module_call_graph ),
396
+ graph_signature = tagged_exported_program . graph_signature ,
397
+ state_dict = tagged_exported_program . state_dict ,
398
+ range_constraints = copy .deepcopy (tagged_exported_program .range_constraints ),
399
+ module_call_graph = copy .deepcopy (tagged_exported_program .module_call_graph ),
399
400
example_inputs = None ,
400
- constants = new_constants ,
401
- verifiers = [edge_program .verifier ],
401
+ constants = tagged_exported_program . constants ,
402
+ verifiers = [tagged_exported_program .verifier ],
402
403
)
0 commit comments