From 7836f9d8bf8247b69d36a527ef7d041d3f463633 Mon Sep 17 00:00:00 2001
From: Angela Yi
Date: Thu, 15 Aug 2024 09:04:37 -0700
Subject: [PATCH] Refactor delegation code (#4566)

Summary:
X-link: https://github.com/pytorch/pytorch/pull/132773

Pull Request resolved: https://github.com/pytorch/executorch/pull/4566

Refactoring partitioner-based delegation to prepare for allowing buffer
mutations in the delegate (following diff).

Reviewed By: cccclai

Differential Revision: D60813405
---
 exir/backend/backend_api.py        | 119 ++++++------
 exir/backend/test/test_backends.py |  24 +--
 exir/backend/utils.py              |   2 -
 exir/lowered_backend_module.py     | 292 +++++++++++++++++++----------
 4 files changed, 252 insertions(+), 185 deletions(-)

diff --git a/exir/backend/backend_api.py b/exir/backend/backend_api.py
index 25c793287d7..d114d8b4705 100644
--- a/exir/backend/backend_api.py
+++ b/exir/backend/backend_api.py
@@ -6,7 +6,7 @@
 import copy
 import logging
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from functools import singledispatch
 from typing import Generator, List
@@ -25,12 +25,11 @@
 from executorch.exir.graph_module import get_control_flow_submodules
 from executorch.exir.lowered_backend_module import (
-    _get_new_signature,
+    _unsafe_adjust_original_program,
     create_exported_program_from_submodule,
     create_submodule_from_nodes,
     LoweredBackendModule,
 )
-from executorch.exir.pass_base import ExportPass
 from executorch.exir.program._fake_program import (
     get_fake_program,
     update_to_real_program,
 )
@@ -193,6 +192,7 @@ def _partition_and_lower_one_graph_module(
     tagged_graph_module: torch.fx.GraphModule,
     partition_result: PartitionResult,
     owning_program: ExportedProgram,
+    is_submodule: bool,
 ) -> torch.fx.GraphModule:
     """
     Partitioned and lowered the graph module based on the partition tag, this is to handle one graph module.
@@ -210,21 +210,40 @@ def _partition_and_lower_one_graph_module(
         logging.debug(f"For tag {tag}, found nodes {node_list}")
         # Tag the nodes that are params as buffers, so we can order the submodule as (Parms + Buffers) (User Inputs)
-        submodule, call_module_node = create_submodule_from_nodes(
-            tagged_graph_module, node_list, tag
+
+        replace_ctx = (
+            tagged_graph_module._set_replace_hook(
+                owning_program.graph_signature.get_replace_hook()
+            )
+            if not is_submodule
+            else nullcontext()
         )
+        with replace_ctx:
+            submodule, call_module_node = create_submodule_from_nodes(
+                tagged_graph_module, node_list, tag
+            )
+
         tagged_graph_module_output_node = [
             node for node in tagged_graph_module.graph.nodes if node.op == "output"
-        ]
+        ][0]
         submodule_output_node = [
             node for node in submodule.graph.nodes if node.op == "output"
-        ]
-        # Copy the output node meta from the original output node, because create_submodule_from_nodes doesn't cover the meta field
-        submodule_output_node[0].meta = tagged_graph_module_output_node[0].meta
+        ][0]
+        # Copy the output node meta from the original output node, because
+        # create_submodule_from_nodes doesn't cover the meta field
+        submodule_output_node.meta = tagged_graph_module_output_node.meta
         logging.debug(f"Partitioned graph module: {tagged_graph_module}")
 
-        submodule_program = create_exported_program_from_submodule(
-            submodule, owning_program, tag
+        (
+            submodule_program,
+            toplevel_input_specs_to_delete,
+            toplevel_output_specs_to_delete,
+        ) = create_exported_program_from_submodule(
+            submodule,
+            owning_program,
+            tag,
+            call_module_node,
+            is_submodule,
         )
 
         lowered_submodule = to_backend(
@@ -257,31 +276,24 @@ def _partition_and_lower_one_graph_module(
         call_delegate_node.meta["debug_handle"] = len(
             tagged_graph_module.graph.nodes
         )
+        call_delegate_node.meta["val"] = submodule_output_node.meta["val"]
         call_module_node.replace_all_uses_with(call_delegate_node)
         tagged_graph_module.graph.erase_node(call_module_node)
 
-        # Delete all parameters/buffers consumed by the created exported program
-        toplevel_signature = owning_program.graph_signature
-        for node in tagged_graph_module.graph.nodes:
-            # Find placeholders consumed by the delegate
-            if node.op != "placeholder" or len(node.users) != 0:
-                continue
-
-            if node.name in toplevel_signature.inputs_to_buffers:
-                # Delete the consumed buffers
-                buffer_name = toplevel_signature.inputs_to_buffers.get(node.name)
-                if buffer_name in owning_program.state_dict:
-                    owning_program.state_dict.pop(buffer_name)
-                else:
-                    owning_program.constants.pop(buffer_name)
-                tagged_graph_module.graph.erase_node(node)
-            elif node.name in toplevel_signature.inputs_to_parameters:
-                # Delete the consumed parameters
-                param_name = toplevel_signature.inputs_to_parameters.get(node.name)
-                owning_program.state_dict.pop(param_name)
-                tagged_graph_module.graph.erase_node(node)
-
-        tagged_graph_module.recompile()
+        if is_submodule:
+            assert len(toplevel_input_specs_to_delete) == 0
+            assert len(toplevel_output_specs_to_delete) == 0
+        elif (
+            len(toplevel_input_specs_to_delete) > 0
+            or len(toplevel_output_specs_to_delete) > 0
+        ):
+            _unsafe_adjust_original_program(
+                owning_program,
+                call_delegate_node,
+                toplevel_input_specs_to_delete,
+                toplevel_output_specs_to_delete,
+            )
+
     return tagged_graph_module
@@ -289,32 +301,23 @@ def _partition_and_lower(
     tagged_graph_module: torch.fx.GraphModule,
     partition_result: PartitionResult,
     owning_program: ExportedProgram,
+    is_submodule: bool = False,
 ) -> torch.fx.GraphModule:
     """
     Partitions the graph module into
     submodules based on tags, and then lowered the nodes with the same tag as
     one lowered module, including the submodule from control flow
     """
     partitioned_module = _partition_and_lower_one_graph_module(
-        tagged_graph_module, partition_result, owning_program
+        tagged_graph_module, partition_result, owning_program, is_submodule
     )
 
     # Recursively partition and lower for submodules
     for name, submod, _node in get_control_flow_submodules(partitioned_module):
         partitioned_submodule = _partition_and_lower(
-            submod, partition_result, owning_program
+            submod, partition_result, owning_program, is_submodule=True
         )
         tagged_graph_module.add_module(name, partitioned_submodule)
 
-    # Run the export pass over the graph module so that the call delegate
-    # nodes will match Edge dialect
-    # TODO(angelayi): ExportPass will rerun the graph, however all we need
-    # here is to add metadata to the call delegate nodes to preserve Edge
-    # dialect. There's work going on to generate a random tensor from a
-    # fake tensor and possibly it can help to address the issue.
-    res = ExportPass()(tagged_graph_module)
-    assert res is not None
-    tagged_graph_module = res.graph_module
-
     return tagged_graph_module
@@ -349,6 +352,8 @@ def to_backend(
     Returns:
         ExportedProgram: The input program, with some portions targeted for delegation.
     """
+    edge_program._validate()
+
     # Use fake program, with FakeTensors in the state dict, to avoid copying large constant values.
     # Fall back to deepcopy if no fake mode is found. TODO(T182910699): Remove this fallback.
     try:
@@ -377,26 +382,22 @@ def to_backend(
     update_to_real_program(tagged_exported_program, edge_program)
 
     for tag, _ in partitioner_result.partition_tags.items():
-        _maybe_duplicate_constant_nodes(tagged_exported_program, tag, edge_program)
+        _maybe_duplicate_constant_nodes(tagged_exported_program, tag)
 
     tagged_graph_module = _partition_and_lower(
-        tagged_exported_program.graph_module, partitioner_result, edge_program
+        tagged_exported_program.graph_module,
+        partitioner_result,
+        tagged_exported_program,
     )
 
-    # TODO(angelayi): Update this signature in a less manual way (maybe through
-    # retracing)
-    new_signature, new_state_dict, new_constants = _get_new_signature(
-        edge_program,
-        tagged_graph_module,
-    )
     return ExportedProgram(
         root=tagged_graph_module,
         graph=tagged_graph_module.graph,
-        graph_signature=new_signature,
-        state_dict=new_state_dict,
-        range_constraints=copy.deepcopy(edge_program.range_constraints),
-        module_call_graph=copy.deepcopy(edge_program.module_call_graph),
+        graph_signature=tagged_exported_program.graph_signature,
+        state_dict=tagged_exported_program.state_dict,
+        range_constraints=copy.deepcopy(tagged_exported_program.range_constraints),
+        module_call_graph=copy.deepcopy(tagged_exported_program.module_call_graph),
         example_inputs=None,
-        constants=new_constants,
-        verifiers=[edge_program.verifier],
+        constants=tagged_exported_program.constants,
+        verifiers=[tagged_exported_program.verifier],
     )
diff --git a/exir/backend/test/test_backends.py b/exir/backend/test/test_backends.py
index 72d61fdf4bf..8c1b7c47ff4 100644
--- a/exir/backend/test/test_backends.py
+++ b/exir/backend/test/test_backends.py
@@ -35,10 +35,7 @@
 from executorch.exir.delegate import executorch_call_delegate
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.graph_module import get_control_flow_submodules
-from executorch.exir.lowered_backend_module import (
-    _get_new_signature,
-    get_lowered_submodules,
-)
+from executorch.exir.lowered_backend_module import get_lowered_submodules
 from executorch.exir.print_program import print_program
 from executorch.exir.schema import (
     BackendDelegate,
@@ -63,7 +60,6 @@
     prepare_fx,
 )
 from torch.export import ExportedProgram
-from torch.export.exported_program import OutputKind, TensorArgument
 from torch.testing import FileCheck
@@ -1270,21 +1266,3 @@ def forward(self, x: List[torch.Tensor]):
 
         gm = exir.capture(ComposedM(), inputs, exir.CaptureConfig()).to_edge()
         gm(*inputs)
-
-    def test_get_new_signature(self):
-        class MyModule(torch.nn.Module):
-            def forward(self, x, y, z):
-                return x + y, y - z, z * x
-
-        ep = torch.export.export(
-            MyModule(), (torch.randn(3, 2), torch.randn(3, 2), torch.randn(3, 2))
-        )
-        sig, *_ = _get_new_signature(ep, ep.graph_module)
-        output_names = set()
-        self.assertEqual(len(sig.output_specs), 3)
-        for s in sig.output_specs:
-            self.assertEqual(s.kind, OutputKind.USER_OUTPUT)
-            self.assertIsInstance(s.arg, TensorArgument)
-            name = s.arg.name
-            self.assertNotIn(name, output_names)
-            output_names.add(name)
diff --git a/exir/backend/utils.py b/exir/backend/utils.py
index 68f36cdb7b4..b5072604d2d 100644
--- a/exir/backend/utils.py
+++ b/exir/backend/utils.py
@@ -208,7 +208,6 @@ def _assign_new_tag(
 def _maybe_duplicate_constant_nodes(
     tagged_exported_program: ExportedProgram,
     tag: str,
-    owning_program: ExportedProgram,
 ) -> None:
     """
     If the constants node is shared by different tagged nodes, like
@@ -241,7 +240,6 @@ def _maybe_duplicate_constant_nodes(
         copied_nodes = copied_nodes.union(
             duplicate_constant_node(tagged_exported_program, candidate_node)
         )
-        duplicate_constant_node(owning_program, candidate_node)
 
     candidate_node_with_copies = candidate_nodes.union(copied_nodes)
     _assign_new_tag(tagged_exported_program, candidate_node_with_copies)
diff --git a/exir/lowered_backend_module.py b/exir/lowered_backend_module.py
index 54d73802036..2c2cd8eb0dd 100644
--- a/exir/lowered_backend_module.py
+++ b/exir/lowered_backend_module.py
@@ -8,7 +8,7 @@
 import copy
 import operator
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Set, Tuple, Union
 
 import torch
 import torch.utils._pytree as pytree
@@ -432,46 +432,67 @@ def arrange_graph_placeholders(
 
 def _get_new_signature(  # noqa: C901
     original_program: ExportedProgram,
     gm: torch.fx.GraphModule,
-    tag: Optional[str] = None,
+    call_module_node: torch.fx.Node,
+    tag: str,
+    is_submodule: bool = False,
 ) -> Tuple[
     ExportGraphSignature,
     Dict[str, Union[torch.Tensor, torch.nn.Parameter]],
     Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]],
+    Dict[str, InputSpec],
+    Dict[str, OutputSpec],
 ]:
     """
     Args:
-        tag: If tag is None, this means that we are constructing the graph
-            signature for the toplevel graph, after delegation. We need to do this
-            because sometimes delegates will swallow some parameters/buffers, so we
-            need to update the graph signature/state dict to reflect these changes.
-            Otherwise, if tag is not None, this means we are constructing the graph
-            signature for the delegated modules. In this case, we need to look
-            through the input nodes and see which ones were originally
-            parameters/buffers, and lower them down to the delegate.
-    """
+        original_program: The original program that we are partitioning
+        gm: The partitioned graph module.
+        call_module_node: The node in the original program that is calling the
+            partitioned graph module.
+        tag: The tag being used for this partitioned submodule. This is used to
+            tell if a particular parameter/buffer/constant node is being tagged,
+            aka consumed by the delegate.
+        is_submodule: True if we are currently partitioning inside of a
+            submodule (like cond's submodule). If we are inside of a submodule,
+            we do not care about consuming params/buffers.
+
+    Returns:
+        new_signature (ExportGraphSignature): The new signature for the
+            partitioned graph module.
+        new_state_dict (Dict[str, Union[torch.Tensor, torch.nn.Parameter]]): The
+            new state dict containing the consumed params/buffers.
+        new_constants (Dict[str, Union[torch.Tensor, FakeScriptObject,
+            torch.ScriptObject]]): The new constants table containing the
+            consumed constants.
+        input_specs_to_delete (Dict[str, InputSpec]): The input specs that have
+            been consumed by the delegate (param/buffer input nodes) and should
+            be removed from the toplevel ExportedProgram.
+        output_specs_to_delete (Dict[str, OutputSpec]): The output specs that have
+            been consumed by the delegate (buffer mutation nodes) and should be
+            removed from the toplevel ExportedProgram.
+    """
     old_signature = original_program.graph_signature
 
     input_specs = []
     output_specs = []
-    new_signature = ExportGraphSignature(
-        input_specs=input_specs, output_specs=output_specs
-    )
+    input_specs_to_delete = {}
+    output_specs_to_delete = {}
 
     new_state_dict = {}
    new_constants = {}
 
-    placeholder_nodes = [
-        node.name for node in original_program.graph.nodes if node.op == "placeholder"
-    ]
-    assert len(placeholder_nodes) == len(old_signature.input_specs)
-    input_node_to_sig = dict(zip(placeholder_nodes, old_signature.input_specs))
+    # If we are within a submodule, we do not need to care about consuming
+    # parameter/buffers
+    input_node_to_sig: Dict[str, InputSpec] = (
+        {input_spec.arg.name: input_spec for input_spec in old_signature.input_specs}
+        if not is_submodule
+        else {}
+    )
 
     for node in gm.graph.nodes:
         is_tagged = tag is None or node.meta.get("delegation_tag", None) == tag
         if node.op == "placeholder":
             if node.name not in input_node_to_sig:
-                assert tag is not None
                 input_specs.append(
                     InputSpec(
                         kind=InputKind.USER_INPUT,
@@ -489,29 +510,36 @@ def _get_new_signature(  # noqa: C901
             elif is_tagged:
                 input_specs.append(orig_input_spec)
 
-                if orig_input_spec.kind == InputKind.PARAMETER:
-                    new_state_dict[orig_input_spec.target] = (
-                        original_program.state_dict[orig_input_spec.target]
-                    )
-                elif (
-                    orig_input_spec.kind == InputKind.BUFFER
-                    and orig_input_spec.persistent
-                ):
-                    new_state_dict[orig_input_spec.target] = (
-                        original_program.state_dict[orig_input_spec.target]
+                if orig_input_spec.kind == InputKind.USER_INPUT:
+                    continue
+
+                # The following input specs are all attributes that should be
+                # consumed by the delegate, so we want to remove them from the
+                # toplevel module input/output
+                input_specs_to_delete[node.name] = orig_input_spec
+
+                input_target = orig_input_spec.target
+                if input_target in original_program.state_dict:
+                    assert orig_input_spec.kind in (
+                        InputKind.PARAMETER,
+                        InputKind.BUFFER,
                     )
-                elif orig_input_spec.kind == InputKind.BUFFER:
-                    assert not orig_input_spec.persistent
-                    new_constants[orig_input_spec.target] = original_program.constants[
-                        orig_input_spec.target
+
+                    new_state_dict[input_target] = original_program.state_dict[
+                        input_target
                     ]
-                elif orig_input_spec.kind in (
-                    InputKind.CONSTANT_TENSOR,
-                    InputKind.CUSTOM_OBJ,
-                ):
-                    new_constants[orig_input_spec.target] = original_program.constants[
-                        orig_input_spec.target
+                elif input_target in original_program.constants:
+                    assert orig_input_spec.kind in (
+                        InputKind.CONSTANT_TENSOR,
+                        InputKind.CUSTOM_OBJ,
+                        InputKind.BUFFER,
+                    )
+
+                    new_constants[input_target] = original_program.constants[
+                        input_target
                     ]
+                else:
+                    raise RuntimeError(f"Invalid input spec {orig_input_spec} received")
 
             else:
                 input_specs.append(
@@ -525,60 +553,46 @@ def _get_new_signature(  # noqa: C901
 
         if node.op == "output":
             output_nodes = pytree.tree_leaves((node.args, node.kwargs))
 
-            if tag is not None:
-                # We are constructing output_specs for the delegate outputs.
-                # These don't have any buffer mutations.
-
-                for output_node in output_nodes:
-                    if not isinstance(output_node, torch.fx.Node):
-                        output_specs.append(
-                            OutputSpec(
-                                kind=OutputKind.USER_OUTPUT,
-                                arg=ConstantArgument(name="", value=output_node),
-                                target=None,
-                            )
-                        )
-                    else:
-                        output_specs.append(
-                            OutputSpec(
-                                kind=OutputKind.USER_OUTPUT,
-                                arg=TensorArgument(name=output_node.name),
-                                target=None,
-                            )
+            for output_node in output_nodes:
+
+                if not isinstance(output_node, torch.fx.Node):
+                    output_specs.append(
+                        OutputSpec(
+                            kind=OutputKind.USER_OUTPUT,
+                            arg=ConstantArgument(name="", value=output_node),
+                            target=None,
                         )
+                    )
 
-            else:
-                # We are reconstruting the toplevel module which contains
-                # delegates. Delegation should not change the number of outputs
-                # in the toplevel module, and it does not touch the mutated buffers
-
-                assert len(old_signature.output_specs) == len(output_nodes)
-                for prev_output_spec, output_node in zip(
-                    old_signature.output_specs, output_nodes
-                ):
-                    if not isinstance(output_node, torch.fx.Node):
-                        assert isinstance(prev_output_spec.arg, ConstantArgument)
-                        output_specs.append(
-                            OutputSpec(
-                                kind=OutputKind.USER_OUTPUT,
-                                arg=ConstantArgument(name="", value=output_node),
-                                target=None,
-                            )
+                else:
+                    output_specs.append(
+                        OutputSpec(
+                            kind=OutputKind.USER_OUTPUT,
+                            arg=TensorArgument(name=output_node.name),
+                            target=None,
                         )
+                    )
 
-                    else:
-                        new_output_spec = copy.deepcopy(prev_output_spec)
-                        new_output_spec.arg.name = output_node.name
-                        output_specs.append(new_output_spec)
+    new_signature = ExportGraphSignature(
+        input_specs=input_specs, output_specs=output_specs
+    )
 
-    return new_signature, new_state_dict, new_constants
+    return (
+        new_signature,
+        new_state_dict,
+        new_constants,
+        input_specs_to_delete,
+        output_specs_to_delete,
+    )
 
 
 def create_exported_program_from_submodule(
     submodule: torch.fx.GraphModule,
     owning_program: ExportedProgram,
     tag: str,
-) -> ExportedProgram:
+    call_module_node: torch.fx.Node,
+    is_submodule: bool,
+) -> Tuple[ExportedProgram, Dict[str, InputSpec], Dict[str, OutputSpec]]:
     """
     Creates an ExportedProgram from the given submodule using the parameters and buffers
     from the top-level owning program
@@ -590,34 +604,52 @@ def create_exported_program_from_submodule(
     Returns:
         The ExportedProgram created from submodule
+        input_specs_to_delete (Dict[str, InputSpec]): The input specs that have
+            been consumed by the delegate (param/buffer input nodes) and should
+            be removed from the toplevel ExportedProgram.
+        output_specs_to_delete (Dict[str, OutputSpec]): The output specs that have
+            been consumed by the delegate (buffer mutation nodes) and should be
+            removed from the toplevel ExportedProgram.
     """
     # Arrange the submodule's placeholders in order
     submodule = arrange_graph_placeholders(submodule, owning_program)
+    # TODO: we probably need to arrange the outputs wrt buffer mutations.
+
     # Get updated graph signature
-    subgraph_signature, subgraph_state_dict, subgraph_constants = _get_new_signature(
-        owning_program, submodule, tag
+    (
+        subgraph_signature,
+        subgraph_state_dict,
+        subgraph_constants,
+        toplevel_input_specs_to_delete,
+        toplevel_output_specs_to_delete,
+    ) = _get_new_signature(
+        owning_program, submodule, call_module_node, tag, is_submodule
     )
 
     in_spec = pytree.tree_flatten((tuple(subgraph_signature.user_inputs), {}))[1]
     out_spec = pytree.tree_flatten(subgraph_signature.user_outputs)[1]
 
-    return ExportedProgram(
-        root=submodule,
-        graph=submodule.graph,
-        graph_signature=subgraph_signature,
-        state_dict=subgraph_state_dict,
-        range_constraints=copy.deepcopy(owning_program.range_constraints),
-        module_call_graph=[
-            ModuleCallEntry(
-                "",
-                ModuleCallSignature(
-                    inputs=[], outputs=[], in_spec=in_spec, out_spec=out_spec
-                ),
-            )
-        ],
-        constants=subgraph_constants,
-        verifiers=[owning_program.verifier],
+    return (
+        ExportedProgram(
+            root=submodule,
+            graph=submodule.graph,
+            graph_signature=subgraph_signature,
+            state_dict=subgraph_state_dict,
+            range_constraints=copy.deepcopy(owning_program.range_constraints),
+            module_call_graph=[
+                ModuleCallEntry(
+                    "",
+                    ModuleCallSignature(
+                        inputs=[], outputs=[], in_spec=in_spec, out_spec=out_spec
+                    ),
+                )
+            ],
+            constants=subgraph_constants,
+            verifiers=[owning_program.verifier],
+        ),
+        toplevel_input_specs_to_delete,
+        toplevel_output_specs_to_delete,
     )
@@ -740,3 +772,61 @@ def get_lowered_backend_modules(
         lowered_programs.append(lowered_backend_module)
 
     return lowered_programs
+
+
+def _unsafe_adjust_original_program(
+    original_program: ExportedProgram,
+    call_delegate_node: torch.fx.Node,
+    input_specs_to_delete: Dict[str, InputSpec],
+    output_specs_to_delete: Dict[str, OutputSpec],
+) -> None:
+    """
+    Directly modify the original exported program's signature and state dict
+    based on the consumed params/buffers in the delegate.
+    """
+    original_program._graph_signature.input_specs = [
+        input_spec
+        for input_spec in original_program.graph_signature.input_specs
+        if input_spec.arg.name not in input_specs_to_delete
+    ]
+
+    currently_used_targets: Set[str] = {
+        input_spec.target
+        for input_spec in original_program._graph_signature.input_specs
+        if input_spec.target is not None
+    }
+
+    original_program._graph_signature.output_specs = [
+        output_spec
+        for output_spec in original_program.graph_signature.output_specs
+        if output_spec.arg.name not in output_specs_to_delete
+    ]
+
+    # Delete all parameters/buffers consumed by the created exported program
+    # from the graph signature, state dict, constants table
+    for node in original_program.graph.nodes:
+        if node.op == "placeholder":
+            if node.name in input_specs_to_delete:
+                assert len(node.users) == 0
+                original_program.graph.erase_node(node)
+        else:
+            break
+
+    for input_spec in input_specs_to_delete.values():
+        input_target = input_spec.target
+        assert input_target is not None
+
+        if input_target in currently_used_targets:
+            continue
+
+        if input_spec.kind == InputKind.PARAMETER:
+            del original_program._state_dict[input_target]
+        elif input_spec.kind == InputKind.BUFFER:
+            if input_spec.persistent:
+                del original_program._state_dict[input_target]
+            else:
+                del original_program._constants[input_spec.target]
+        elif input_spec.kind == InputKind.CONSTANT_TENSOR:
+            del original_program._constants[input_spec.target]
+        else:
+            raise RuntimeError(f"Invalid input spec {input_spec} received")
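Note on the replace_ctx idiom added in exir/backend/backend_api.py above: the replace hook from the owning program's graph signature is installed only when partitioning the toplevel program, while control-flow submodules run the same code under a no-op context. The sketch below shows that conditional context-manager pattern in isolation, assuming nothing beyond the Python standard library; install_hook and run_partition_step are illustrative stand-ins, not ExecuTorch APIs.

from contextlib import contextmanager, nullcontext
from typing import List


@contextmanager
def install_hook(events: List[str]):
    # Stand-in for tagged_graph_module._set_replace_hook(...): record that
    # the hook was active around the body of the `with` block.
    events.append("hook installed")
    try:
        yield
    finally:
        events.append("hook removed")


def run_partition_step(is_submodule: bool, events: List[str]) -> None:
    # Pick the real hook context for the toplevel program, or a no-op
    # nullcontext() for submodules, then run the work under `with` either way.
    replace_ctx = install_hook(events) if not is_submodule else nullcontext()
    with replace_ctx:
        events.append("create_submodule_from_nodes")


events: List[str] = []
run_partition_step(is_submodule=False, events=events)  # hook wraps the work
run_partition_step(is_submodule=True, events=events)   # no-op context
print(events)

Because nullcontext() satisfies the same context-manager protocol as the real hook, the calling code needs no branching around the `with` statement itself.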