Skip to content

Commit 3e4508a

Browse files
authored
Refactor delegation code
Differential Revision: D60813405 Pull Request resolved: #4566
1 parent ae299cf commit 3e4508a

File tree

4 files changed

+252
-185
lines changed

4 files changed

+252
-185
lines changed

exir/backend/backend_api.py

Lines changed: 60 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import copy
88
import logging
9-
from contextlib import contextmanager
9+
from contextlib import contextmanager, nullcontext
1010
from functools import singledispatch
1111
from typing import Generator, List
1212

@@ -25,12 +25,11 @@
2525

2626
from executorch.exir.graph_module import get_control_flow_submodules
2727
from executorch.exir.lowered_backend_module import (
28-
_get_new_signature,
28+
_unsafe_adjust_original_program,
2929
create_exported_program_from_submodule,
3030
create_submodule_from_nodes,
3131
LoweredBackendModule,
3232
)
33-
from executorch.exir.pass_base import ExportPass
3433
from executorch.exir.program._fake_program import (
3534
get_fake_program,
3635
update_to_real_program,
@@ -193,6 +192,7 @@ def _partition_and_lower_one_graph_module(
193192
tagged_graph_module: torch.fx.GraphModule,
194193
partition_result: PartitionResult,
195194
owning_program: ExportedProgram,
195+
is_submodule: bool,
196196
) -> torch.fx.GraphModule:
197197
"""
198198
Partitions and lowers the graph module based on the partition tag; this handles one graph module.
@@ -210,21 +210,40 @@ def _partition_and_lower_one_graph_module(
210210

211211
logging.debug(f"For tag {tag}, found nodes {node_list}")
212212
# Tag the nodes that are params as buffers, so we can order the submodule as (Params + Buffers) (User Inputs)
213-
submodule, call_module_node = create_submodule_from_nodes(
214-
tagged_graph_module, node_list, tag
213+
214+
replace_ctx = (
215+
tagged_graph_module._set_replace_hook(
216+
owning_program.graph_signature.get_replace_hook()
217+
)
218+
if not is_submodule
219+
else nullcontext()
215220
)
221+
with replace_ctx:
222+
submodule, call_module_node = create_submodule_from_nodes(
223+
tagged_graph_module, node_list, tag
224+
)
225+
216226
tagged_graph_module_output_node = [
217227
node for node in tagged_graph_module.graph.nodes if node.op == "output"
218-
]
228+
][0]
219229
submodule_output_node = [
220230
node for node in submodule.graph.nodes if node.op == "output"
221-
]
222-
# Copy the output node meta from the original output node, because create_submodule_from_nodes doesn't cover the meta field
223-
submodule_output_node[0].meta = tagged_graph_module_output_node[0].meta
231+
][0]
232+
# Copy the output node meta from the original output node, because
233+
# create_submodule_from_nodes doesn't cover the meta field
234+
submodule_output_node.meta = tagged_graph_module_output_node.meta
224235
logging.debug(f"Partitioned graph module: {tagged_graph_module}")
225236

226-
submodule_program = create_exported_program_from_submodule(
227-
submodule, owning_program, tag
237+
(
238+
submodule_program,
239+
toplevel_input_specs_to_delete,
240+
toplevel_output_specs_to_delete,
241+
) = create_exported_program_from_submodule(
242+
submodule,
243+
owning_program,
244+
tag,
245+
call_module_node,
246+
is_submodule,
228247
)
229248

230249
lowered_submodule = to_backend(
@@ -257,64 +276,48 @@ def _partition_and_lower_one_graph_module(
257276
call_delegate_node.meta["debug_handle"] = len(
258277
tagged_graph_module.graph.nodes
259278
)
279+
call_delegate_node.meta["val"] = submodule_output_node.meta["val"]
260280
call_module_node.replace_all_uses_with(call_delegate_node)
261281
tagged_graph_module.graph.erase_node(call_module_node)
262282

263-
# Delete all parameters/buffers consumed by the created exported program
264-
toplevel_signature = owning_program.graph_signature
265-
for node in tagged_graph_module.graph.nodes:
266-
# Find placeholders consumed by the delegate
267-
if node.op != "placeholder" or len(node.users) != 0:
268-
continue
269-
270-
if node.name in toplevel_signature.inputs_to_buffers:
271-
# Delete the consumed buffers
272-
buffer_name = toplevel_signature.inputs_to_buffers.get(node.name)
273-
if buffer_name in owning_program.state_dict:
274-
owning_program.state_dict.pop(buffer_name)
275-
else:
276-
owning_program.constants.pop(buffer_name)
277-
tagged_graph_module.graph.erase_node(node)
278-
elif node.name in toplevel_signature.inputs_to_parameters:
279-
# Delete the consumed parameters
280-
param_name = toplevel_signature.inputs_to_parameters.get(node.name)
281-
owning_program.state_dict.pop(param_name)
282-
tagged_graph_module.graph.erase_node(node)
283-
284-
tagged_graph_module.recompile()
283+
if is_submodule:
284+
assert len(toplevel_input_specs_to_delete) == 0
285+
assert len(toplevel_output_specs_to_delete) == 0
286+
elif (
287+
len(toplevel_input_specs_to_delete) > 0
288+
or len(toplevel_output_specs_to_delete) > 0
289+
):
290+
_unsafe_adjust_original_program(
291+
owning_program,
292+
call_delegate_node,
293+
toplevel_input_specs_to_delete,
294+
toplevel_output_specs_to_delete,
295+
)
296+
285297
return tagged_graph_module
286298

287299

288300
def _partition_and_lower(
289301
tagged_graph_module: torch.fx.GraphModule,
290302
partition_result: PartitionResult,
291303
owning_program: ExportedProgram,
304+
is_submodule: bool = False,
292305
) -> torch.fx.GraphModule:
293306
"""
294307
Partitions the graph module into submodules based on tags, and then lowers the nodes with the same tag as one lowered module, including the submodules from control flow
295308
"""
296309

297310
partitioned_module = _partition_and_lower_one_graph_module(
298-
tagged_graph_module, partition_result, owning_program
311+
tagged_graph_module, partition_result, owning_program, is_submodule
299312
)
300313

301314
# Recursively partition and lower for submodules
302315
for name, submod, _node in get_control_flow_submodules(partitioned_module):
303316
partitioned_submodule = _partition_and_lower(
304-
submod, partition_result, owning_program
317+
submod, partition_result, owning_program, is_submodule=True
305318
)
306319
tagged_graph_module.add_module(name, partitioned_submodule)
307320

308-
# Run the export pass over the graph module so that the call delegate
309-
# nodes will match Edge dialect
310-
# TODO(angelayi): ExportPass will rerun the graph, however all we need
311-
# here is to add metadata to the call delegate nodes to preserve Edge
312-
# dialect. There's work going on to generate a random tensor from a
313-
# fake tensor and possibly it can help to address the issue.
314-
res = ExportPass()(tagged_graph_module)
315-
assert res is not None
316-
tagged_graph_module = res.graph_module
317-
318321
return tagged_graph_module
319322

320323

@@ -349,6 +352,8 @@ def to_backend(
349352
Returns:
350353
ExportedProgram: The input program, with some portions targeted for delegation.
351354
"""
355+
edge_program._validate()
356+
352357
# Use fake program, with FakeTensors in the state dict, to avoid copying large constant values.
353358
# Fall back to deepcopy if no fake mode is found. TODO(T182910699): Remove this fallback.
354359
try:
@@ -377,26 +382,22 @@ def to_backend(
377382
update_to_real_program(tagged_exported_program, edge_program)
378383

379384
for tag, _ in partitioner_result.partition_tags.items():
380-
_maybe_duplicate_constant_nodes(tagged_exported_program, tag, edge_program)
385+
_maybe_duplicate_constant_nodes(tagged_exported_program, tag)
381386

382387
tagged_graph_module = _partition_and_lower(
383-
tagged_exported_program.graph_module, partitioner_result, edge_program
388+
tagged_exported_program.graph_module,
389+
partitioner_result,
390+
tagged_exported_program,
384391
)
385392

386-
# TODO(angelayi): Update this signature in a less manual way (maybe through
387-
# retracing)
388-
new_signature, new_state_dict, new_constants = _get_new_signature(
389-
edge_program,
390-
tagged_graph_module,
391-
)
392393
return ExportedProgram(
393394
root=tagged_graph_module,
394395
graph=tagged_graph_module.graph,
395-
graph_signature=new_signature,
396-
state_dict=new_state_dict,
397-
range_constraints=copy.deepcopy(edge_program.range_constraints),
398-
module_call_graph=copy.deepcopy(edge_program.module_call_graph),
396+
graph_signature=tagged_exported_program.graph_signature,
397+
state_dict=tagged_exported_program.state_dict,
398+
range_constraints=copy.deepcopy(tagged_exported_program.range_constraints),
399+
module_call_graph=copy.deepcopy(tagged_exported_program.module_call_graph),
399400
example_inputs=None,
400-
constants=new_constants,
401-
verifiers=[edge_program.verifier],
401+
constants=tagged_exported_program.constants,
402+
verifiers=[tagged_exported_program.verifier],
402403
)

exir/backend/test/test_backends.py

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,7 @@
3535
from executorch.exir.delegate import executorch_call_delegate
3636
from executorch.exir.dialects._ops import ops as exir_ops
3737
from executorch.exir.graph_module import get_control_flow_submodules
38-
from executorch.exir.lowered_backend_module import (
39-
_get_new_signature,
40-
get_lowered_submodules,
41-
)
38+
from executorch.exir.lowered_backend_module import get_lowered_submodules
4239
from executorch.exir.print_program import print_program
4340
from executorch.exir.schema import (
4441
BackendDelegate,
@@ -63,7 +60,6 @@
6360
prepare_fx,
6461
)
6562
from torch.export import ExportedProgram
66-
from torch.export.exported_program import OutputKind, TensorArgument
6763
from torch.testing import FileCheck
6864

6965

@@ -1270,21 +1266,3 @@ def forward(self, x: List[torch.Tensor]):
12701266

12711267
gm = exir.capture(ComposedM(), inputs, exir.CaptureConfig()).to_edge()
12721268
gm(*inputs)
1273-
1274-
def test_get_new_signature(self):
1275-
class MyModule(torch.nn.Module):
1276-
def forward(self, x, y, z):
1277-
return x + y, y - z, z * x
1278-
1279-
ep = torch.export.export(
1280-
MyModule(), (torch.randn(3, 2), torch.randn(3, 2), torch.randn(3, 2))
1281-
)
1282-
sig, *_ = _get_new_signature(ep, ep.graph_module)
1283-
output_names = set()
1284-
self.assertEqual(len(sig.output_specs), 3)
1285-
for s in sig.output_specs:
1286-
self.assertEqual(s.kind, OutputKind.USER_OUTPUT)
1287-
self.assertIsInstance(s.arg, TensorArgument)
1288-
name = s.arg.name
1289-
self.assertNotIn(name, output_names)
1290-
output_names.add(name)

exir/backend/utils.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,6 @@ def _assign_new_tag(
208208
def _maybe_duplicate_constant_nodes(
209209
tagged_exported_program: ExportedProgram,
210210
tag: str,
211-
owning_program: ExportedProgram,
212211
) -> None:
213212
"""
214213
If the constants node is shared by different tagged nodes, like
@@ -241,7 +240,6 @@ def _maybe_duplicate_constant_nodes(
241240
copied_nodes = copied_nodes.union(
242241
duplicate_constant_node(tagged_exported_program, candidate_node)
243242
)
244-
duplicate_constant_node(owning_program, candidate_node)
245243
candidate_node_with_copies = candidate_nodes.union(copied_nodes)
246244
_assign_new_tag(tagged_exported_program, candidate_node_with_copies)
247245

0 commit comments

Comments
 (0)