Refactor delegation code #4566
Merged 1 commit on Aug 15, 2024.

119 changes: 60 additions & 59 deletions exir/backend/backend_api.py

@@ -6,7 +6,7 @@
 import copy
 import logging
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from functools import singledispatch
 from typing import Generator, List
@@ -25,12 +25,11 @@
 from executorch.exir.graph_module import get_control_flow_submodules
 from executorch.exir.lowered_backend_module import (
-    _get_new_signature,
+    _unsafe_adjust_original_program,
     create_exported_program_from_submodule,
     create_submodule_from_nodes,
     LoweredBackendModule,
 )
-from executorch.exir.pass_base import ExportPass
 from executorch.exir.program._fake_program import (
     get_fake_program,
     update_to_real_program,
@@ -193,6 +192,7 @@ def _partition_and_lower_one_graph_module(
     tagged_graph_module: torch.fx.GraphModule,
     partition_result: PartitionResult,
     owning_program: ExportedProgram,
+    is_submodule: bool,
 ) -> torch.fx.GraphModule:
     """
     Partitioned and lowered the graph module based on the partition tag, this is to handle one graph module.
@@ -210,21 +210,40 @@

         logging.debug(f"For tag {tag}, found nodes {node_list}")
         # Tag the nodes that are params as buffers, so we can order the submodule as (Parms + Buffers) (User Inputs)
-        submodule, call_module_node = create_submodule_from_nodes(
-            tagged_graph_module, node_list, tag
-        )
+
+        replace_ctx = (
+            tagged_graph_module._set_replace_hook(
+                owning_program.graph_signature.get_replace_hook()
+            )
+            if not is_submodule
+            else nullcontext()
+        )
+        with replace_ctx:
+            submodule, call_module_node = create_submodule_from_nodes(
+                tagged_graph_module, node_list, tag
+            )

         tagged_graph_module_output_node = [
             node for node in tagged_graph_module.graph.nodes if node.op == "output"
-        ]
+        ][0]
         submodule_output_node = [
             node for node in submodule.graph.nodes if node.op == "output"
-        ]
-        # Copy the output node meta from the original output node, because create_submodule_from_nodes doesn't cover the meta field
-        submodule_output_node[0].meta = tagged_graph_module_output_node[0].meta
+        ][0]
+        # Copy the output node meta from the original output node, because
+        # create_submodule_from_nodes doesn't cover the meta field
+        submodule_output_node.meta = tagged_graph_module_output_node.meta
         logging.debug(f"Partitioned graph module: {tagged_graph_module}")

-        submodule_program = create_exported_program_from_submodule(
-            submodule, owning_program, tag
-        )
+        (
+            submodule_program,
+            toplevel_input_specs_to_delete,
+            toplevel_output_specs_to_delete,
+        ) = create_exported_program_from_submodule(
+            submodule,
+            owning_program,
+            tag,
+            call_module_node,
+            is_submodule,
+        )

         lowered_submodule = to_backend(
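
Note on the `replace_ctx` hunk above: this is the standard Python idiom for conditionally entering a context manager. A conditional expression picks either the real replace hook or `contextlib.nullcontext()`, and the `with` body stays identical in both cases. A minimal, self-contained sketch of the pattern (the hook below is a hypothetical stand-in, not the ExecuTorch one):

```python
from contextlib import contextmanager, nullcontext

@contextmanager
def replace_hook():
    # Hypothetical stand-in for graph_signature.get_replace_hook():
    # install some bookkeeping on entry, tear it down on exit.
    print("hook installed")
    try:
        yield
    finally:
        print("hook removed")

def mutate_graph(is_submodule: bool) -> None:
    # nullcontext() is a no-op context manager, so the `with` body runs
    # unchanged when no hook is needed.
    ctx = replace_hook() if not is_submodule else nullcontext()
    with ctx:
        print("creating submodule from nodes")

mutate_graph(is_submodule=False)  # hook installed / creating ... / hook removed
mutate_graph(is_submodule=True)   # creating submodule from nodes
```

In the submodule case there is no top-level signature to keep in sync (the asserts later in this diff check that nothing needs deleting), so the no-op context applies.
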
@@ -257,64 +276,48 @@ def _partition_and_lower_one_graph_module(
             call_delegate_node.meta["debug_handle"] = len(
                 tagged_graph_module.graph.nodes
             )
+            call_delegate_node.meta["val"] = submodule_output_node.meta["val"]
             call_module_node.replace_all_uses_with(call_delegate_node)
             tagged_graph_module.graph.erase_node(call_module_node)

-        # Delete all parameters/buffers consumed by the created exported program
-        toplevel_signature = owning_program.graph_signature
-        for node in tagged_graph_module.graph.nodes:
-            # Find placeholders consumed by the delegate
-            if node.op != "placeholder" or len(node.users) != 0:
-                continue
-
-            if node.name in toplevel_signature.inputs_to_buffers:
-                # Delete the consumed buffers
-                buffer_name = toplevel_signature.inputs_to_buffers.get(node.name)
-                if buffer_name in owning_program.state_dict:
-                    owning_program.state_dict.pop(buffer_name)
-                else:
-                    owning_program.constants.pop(buffer_name)
-                tagged_graph_module.graph.erase_node(node)
-            elif node.name in toplevel_signature.inputs_to_parameters:
-                # Delete the consumed parameters
-                param_name = toplevel_signature.inputs_to_parameters.get(node.name)
-                owning_program.state_dict.pop(param_name)
-                tagged_graph_module.graph.erase_node(node)
-
-        tagged_graph_module.recompile()
+        if is_submodule:
+            assert len(toplevel_input_specs_to_delete) == 0
+            assert len(toplevel_output_specs_to_delete) == 0
+        elif (
+            len(toplevel_input_specs_to_delete) > 0
+            or len(toplevel_output_specs_to_delete) > 0
+        ):
+            _unsafe_adjust_original_program(
+                owning_program,
+                call_delegate_node,
+                toplevel_input_specs_to_delete,
+                toplevel_output_specs_to_delete,
+            )

     return tagged_graph_module


 def _partition_and_lower(
     tagged_graph_module: torch.fx.GraphModule,
     partition_result: PartitionResult,
     owning_program: ExportedProgram,
+    is_submodule: bool = False,
 ) -> torch.fx.GraphModule:
     """
     Partitions the graph module into submodules based on tags, and then lowered the nodes with the same tag as one lowered module, including the submodule from control flow
     """

     partitioned_module = _partition_and_lower_one_graph_module(
-        tagged_graph_module, partition_result, owning_program
+        tagged_graph_module, partition_result, owning_program, is_submodule
     )

     # Recursively partition and lower for submodules
     for name, submod, _node in get_control_flow_submodules(partitioned_module):
         partitioned_submodule = _partition_and_lower(
-            submod, partition_result, owning_program
+            submod, partition_result, owning_program, is_submodule=True
         )
         tagged_graph_module.add_module(name, partitioned_submodule)

-    # Run the export pass over the graph module so that the call delegate
-    # nodes will match Edge dialect
-    # TODO(angelayi): ExportPass will rerun the graph, however all we need
-    # here is to add metadata to the call delegate nodes to preserve Edge
-    # dialect. There's work going on to generate a random tensor from a
-    # fake tensor and possibly it can help to address the issue.
-    res = ExportPass()(tagged_graph_module)
-    assert res is not None
-    tagged_graph_module = res.graph_module
-
     return tagged_graph_module
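
The deleted block above dropped delegate-consumed placeholders (and their parameters/buffers) from the top-level program by hand; that bookkeeping now flows through the `toplevel_*_specs_to_delete` values returned by `create_exported_program_from_submodule` and applied by `_unsafe_adjust_original_program`. The core graph surgery it performed looks like this in plain `torch.fx` (a standalone sketch, not ExecuTorch code):

```python
import torch
from torch import fx

class M(torch.nn.Module):
    def forward(self, x, unused):
        return x + 1

gm = fx.symbolic_trace(M())

# Erase placeholders that no longer have any users, the way the removed
# loop erased inputs consumed only by the lowered submodule.
for node in list(gm.graph.nodes):
    if node.op == "placeholder" and len(node.users) == 0:
        gm.graph.erase_node(node)
gm.recompile()

print(gm.code)  # forward now only takes x; `unused` is gone
```
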
@@ -349,6 +352,8 @@ def to_backend(
     Returns:
         ExportedProgram: The input program, with some portions targeted for delegation.
     """
+    edge_program._validate()
+
     # Use fake program, with FakeTensors in the state dict, to avoid copying large constant values.
     # Fall back to deepcopy if no fake mode is found. TODO(T182910699): Remove this fallback.
     try:
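
The `try` block opening above implements the fallback the comment describes. A condensed sketch of the pattern (the helper name `_copy_for_partitioning` and the broad `except` are assumptions for illustration; only `get_fake_program` is taken from this diff's imports):

```python
import copy
import logging

from executorch.exir.program._fake_program import get_fake_program

def _copy_for_partitioning(edge_program):
    # Prefer a copy whose state dict holds FakeTensors, so large constants
    # are never materialized; fall back to a full deepcopy when the program
    # has no fake mode to build them from.
    try:
        return get_fake_program(edge_program)
    except Exception:
        logging.warning("No fake mode found, falling back to deepcopy")
        return copy.deepcopy(edge_program)
```
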
@@ -377,26 +382,22 @@
     update_to_real_program(tagged_exported_program, edge_program)

     for tag, _ in partitioner_result.partition_tags.items():
-        _maybe_duplicate_constant_nodes(tagged_exported_program, tag, edge_program)
+        _maybe_duplicate_constant_nodes(tagged_exported_program, tag)

     tagged_graph_module = _partition_and_lower(
-        tagged_exported_program.graph_module, partitioner_result, edge_program
+        tagged_exported_program.graph_module,
+        partitioner_result,
+        tagged_exported_program,
     )

-    # TODO(angelayi): Update this signature in a less manual way (maybe through
-    # retracing)
-    new_signature, new_state_dict, new_constants = _get_new_signature(
-        edge_program,
-        tagged_graph_module,
-    )
     return ExportedProgram(
         root=tagged_graph_module,
         graph=tagged_graph_module.graph,
-        graph_signature=new_signature,
-        state_dict=new_state_dict,
-        range_constraints=copy.deepcopy(edge_program.range_constraints),
-        module_call_graph=copy.deepcopy(edge_program.module_call_graph),
+        graph_signature=tagged_exported_program.graph_signature,
+        state_dict=tagged_exported_program.state_dict,
+        range_constraints=copy.deepcopy(tagged_exported_program.range_constraints),
+        module_call_graph=copy.deepcopy(tagged_exported_program.module_call_graph),
         example_inputs=None,
-        constants=new_constants,
-        verifiers=[edge_program.verifier],
+        constants=tagged_exported_program.constants,
+        verifiers=[tagged_exported_program.verifier],
     )

24 changes: 1 addition & 23 deletions exir/backend/test/test_backends.py

@@ -35,10 +35,7 @@
 from executorch.exir.delegate import executorch_call_delegate
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.graph_module import get_control_flow_submodules
-from executorch.exir.lowered_backend_module import (
-    _get_new_signature,
-    get_lowered_submodules,
-)
+from executorch.exir.lowered_backend_module import get_lowered_submodules
 from executorch.exir.print_program import print_program
 from executorch.exir.schema import (
     BackendDelegate,
@@ -63,7 +60,6 @@
     prepare_fx,
 )
 from torch.export import ExportedProgram
-from torch.export.exported_program import OutputKind, TensorArgument
 from torch.testing import FileCheck

@@ -1270,21 +1266,3 @@ def forward(self, x: List[torch.Tensor]):

         gm = exir.capture(ComposedM(), inputs, exir.CaptureConfig()).to_edge()
         gm(*inputs)
-
-    def test_get_new_signature(self):
-        class MyModule(torch.nn.Module):
-            def forward(self, x, y, z):
-                return x + y, y - z, z * x
-
-        ep = torch.export.export(
-            MyModule(), (torch.randn(3, 2), torch.randn(3, 2), torch.randn(3, 2))
-        )
-        sig, *_ = _get_new_signature(ep, ep.graph_module)
-        output_names = set()
-        self.assertEqual(len(sig.output_specs), 3)
-        for s in sig.output_specs:
-            self.assertEqual(s.kind, OutputKind.USER_OUTPUT)
-            self.assertIsInstance(s.arg, TensorArgument)
-            name = s.arg.name
-            self.assertNotIn(name, output_names)
-            output_names.add(name)
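
The deleted test exercised `_get_new_signature` directly; since `to_backend` now reuses the tagged program's own signature, that entry point is gone from this path. The invariant it checked (unique, user-facing tensor outputs) can still be demonstrated against a plain `torch.export` signature (a standalone sketch, not part of this PR):

```python
import torch
from torch.export import export
from torch.export.exported_program import OutputKind, TensorArgument

class MyModule(torch.nn.Module):
    def forward(self, x, y, z):
        return x + y, y - z, z * x

ep = export(
    MyModule(), (torch.randn(3, 2), torch.randn(3, 2), torch.randn(3, 2))
)

# Every output should be a uniquely named, user-facing tensor.
seen = set()
output_specs = ep.graph_signature.output_specs
assert len(output_specs) == 3
for spec in output_specs:
    assert spec.kind == OutputKind.USER_OUTPUT
    assert isinstance(spec.arg, TensorArgument)
    assert spec.arg.name not in seen
    seen.add(spec.arg.name)
```
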

2 changes: 0 additions & 2 deletions exir/backend/utils.py

@@ -208,7 +208,6 @@ def _assign_new_tag(
 def _maybe_duplicate_constant_nodes(
     tagged_exported_program: ExportedProgram,
     tag: str,
-    owning_program: ExportedProgram,
 ) -> None:
     """
     If the constants node is shared by different tagged nodes, like
@@ -241,7 +240,6 @@ def _maybe_duplicate_constant_nodes(
         copied_nodes = copied_nodes.union(
             duplicate_constant_node(tagged_exported_program, candidate_node)
         )
-        duplicate_constant_node(owning_program, candidate_node)
     candidate_node_with_copies = candidate_nodes.union(copied_nodes)
     _assign_new_tag(tagged_exported_program, candidate_node_with_copies)
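
With `to_backend` now mutating a single tagged program, the second `duplicate_constant_node` call on the separate `owning_program` becomes redundant, so the parameter is dropped. The kind of node duplication the remaining call performs can be sketched in plain `torch.fx` (illustrative only; the real helper also clones the backing tensor and updates the graph signature):

```python
import torch
from torch import fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("const", torch.ones(2))

    def forward(self, x):
        return x + self.const, x * self.const  # one constant, two consumers

gm = fx.symbolic_trace(M())
for node in list(gm.graph.nodes):
    if node.op == "get_attr" and len(node.users) > 1:
        # Give each consumer after the first its own copy of the constant
        # node, so the consumers can later be tagged for different partitions.
        for user in list(node.users)[1:]:
            with gm.graph.inserting_before(user):
                clone = gm.graph.node_copy(node)
            user.replace_input_with(node, clone)
gm.recompile()
print(gm.code)
```
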