diff --git a/exir/backend/canonical_partitioners/TARGETS b/exir/backend/canonical_partitioners/TARGETS
index 22a6e2c51bd..8d3e28968b3 100644
--- a/exir/backend/canonical_partitioners/TARGETS
+++ b/exir/backend/canonical_partitioners/TARGETS
@@ -7,6 +7,7 @@ runtime.python_library(
     srcs = [
         "duplicate_dequant_node_pass.py",
         "pattern_op_partitioner.py",
+        "all_node_partitioner.py",
     ],
     visibility = [
         "//executorch/...",
diff --git a/exir/backend/canonical_partitioners/all_node_partitioner.py b/exir/backend/canonical_partitioners/all_node_partitioner.py
new file mode 100644
index 00000000000..bc45f2b5239
--- /dev/null
+++ b/exir/backend/canonical_partitioners/all_node_partitioner.py
@@ -0,0 +1,55 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Dict, List
+
+import torch
+from executorch.exir.backend.backend_details import ExportedProgram
+from executorch.exir.backend.compile_spec_schema import CompileSpec
+from executorch.exir.backend.partitioner import (
+    DelegationSpec,
+    Partitioner,
+    PartitionResult,
+)
+from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
+
+
+def is_non_tensor_placeholder(node: torch.fx.Node, ep: ExportedProgram) -> bool:
+    """
+    Returns True if the node is a placeholder that does not hold a tensor.
+    """
+    return node.op == "placeholder" and not (
+        is_param(ep, node) or is_buffer(ep, node) or is_lifted_tensor_constant(ep, node)
+    )
+
+
+class AllNodePartitioner(Partitioner):
+    def __init__(
+        self,
+        backend_id: str,
+        compile_specs: List[CompileSpec],
+    ):
+        """
+        Partitioner that lowers every single node in the graph module to the
+        specified backend_id.
+        """
+        super().__init__()
+        self.delegation_spec = DelegationSpec(backend_id, compile_specs)
+
+    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
+        # Tag every node except non-tensor placeholders and the output node.
+        partition_tags: Dict[str, DelegationSpec] = {}
+        for node in exported_program.graph_module.graph.nodes:
+            if is_non_tensor_placeholder(node, exported_program) or node.op == "output":
+                continue
+
+            delegation_tag = self.delegation_spec.backend_id
+            node.meta["delegation_tag"] = delegation_tag
+            partition_tags[delegation_tag] = self.delegation_spec
+
+        return PartitionResult(
+            tagged_exported_program=exported_program, partition_tags=partition_tags
+        )
diff --git a/exir/backend/test/test_backends.py b/exir/backend/test/test_backends.py
index d2bcfa31676..b5a38d875c2 100644
--- a/exir/backend/test/test_backends.py
+++ b/exir/backend/test/test_backends.py
@@ -10,7 +10,11 @@
 import executorch.exir as exir
 
 import torch
+from executorch.exir import to_edge
 from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend
+from executorch.exir.backend.canonical_partitioners.all_node_partitioner import (
+    AllNodePartitioner,
+)
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.backend.partitioner import (
     DelegationSpec,
@@ -1266,3 +1270,178 @@ def forward(self, x: List[torch.Tensor]):
 
         gm = exir.capture(ComposedM(), inputs, exir.CaptureConfig()).to_edge()
         gm(*inputs)
+
+    def test_to_backend_delegation_spec(self):
+        class SinModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return [torch.sin(x)]
+
+        sin_module = SinModule()
+        model_inputs = (torch.ones(1),)
+        max_value = model_inputs[0].shape[0]
+        partitioner = AllNodePartitioner(
+            "BackendWithCompilerDemo", [CompileSpec("max_value", bytes([max_value]))]
+        )
+
+        edgeir_m = to_edge(torch.export.export(sin_module, model_inputs))
+        edgeir_m = edgeir_m.to_backend(partitioner)
+        exec_prog = edgeir_m.to_executorch()
+        graph_module = exec_prog.exported_program().graph_module
+        # Check that there is not an aten.sin node.
+        self.assertTrue(
+            exir_ops.edge.aten.sin
+            not in {node.target for node in graph_module.graph.nodes}
+        )
+
+        # Check that there exists a call_delegate, representing the call to the
+        # delegated function
+        FileCheck().check("torch.ops.higher_order.executorch_call_delegate").run(
+            graph_module.code
+        )
+        lowered_submodules = get_lowered_submodules(graph_module)
+        self.assertEqual(len(lowered_submodules), 1)
+
+        for node in graph_module.graph.nodes:
+            if node.op == "call_function" and node.target == executorch_call_delegate:
+                # Check that first arg is lowered_module_{unique_id}
+                self.assertEqual(node.args[0].target, "lowered_module_0")
+
+        program = exec_prog.executorch_program
+
+        # Check the program can be printed
+        print_program(program)
+
+        # Check the backend delegate
+        self.check_backend_delegate(
+            program=program,
+            delegate=program.execution_plan[0].delegates[0],
+            expected_id=BackendWithCompilerDemo.__name__,
+            expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float322#",
+        )
+
+        # Check the delegate instruction
+        self.assertTrue(
+            isinstance(
+                program.execution_plan[0].chains[0].instructions[0].instr_args,
+                DelegateCall,
+            )
+        )
+        buff = exec_prog.buffer
+
+        executorch_module = _load_for_executorch_from_buffer(buff)
+        model_inputs = torch.ones(1)
+        model_outputs = executorch_module.forward([model_inputs])
+        self.assertEqual(
+            model_inputs,
+            torch.ones(1),
+        )
+        expected_output = 0.8333 * torch.ones(1)
+
+        self.assertTrue(
+            torch.allclose(model_outputs[0], expected_output, atol=1e-03, rtol=1e-03)
+        )
+
+    def test_to_backend_multimethod_delegation_spec(self):
+        class SinModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.sin(x)
+
+            def inputs(self):
+                return (torch.ones(1),)
+
+        class AddMulModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, a, x, b):
+                y = torch.mm(a, x)
+                z = torch.add(y, b)
+                return z
+
+            def inputs(self):
+                return (torch.ones(2, 2), 2 * torch.ones(2, 2), 3 * torch.ones(2, 2))
+
+        sin_module = SinModule()
+        max_value_sin = sin_module.inputs()[0].shape[0]
+        sin_partitioner = AllNodePartitioner(
+            "BackendWithCompilerDemo",
+            [CompileSpec("max_value", bytes([max_value_sin]))],
+        )
+
+        add_mul_module = AddMulModule()
+        max_value_add_mul = add_mul_module.inputs()[0].shape[0]
+        add_mul_partitioner = AllNodePartitioner(
+            "BackendWithCompilerDemo",
+            [CompileSpec("max_value", bytes([max_value_add_mul]))],
+        )
+
+        edgeir_m = to_edge(
+            {
+                "sin": torch.export.export(sin_module, sin_module.inputs()),
+                "add_mul": torch.export.export(add_mul_module, add_mul_module.inputs()),
+            }
+        )
+        edgeir_m = edgeir_m.to_backend(
+            {
+                "sin": sin_partitioner,
+                "add_mul": add_mul_partitioner,
+            }
+        )
+        exec_prog = edgeir_m.to_executorch()
+
+        for method_name in ["sin", "add_mul"]:
+            graph_module = exec_prog.exported_program(method_name).graph_module
+            # Check delegated nodes are gone
+            self.assertTrue(
+                exir_ops.edge.aten.sin
+                not in {node.target for node in graph_module.graph.nodes}
+            )
+            self.assertTrue(
+                exir_ops.edge.aten.add
+                not in {node.target for node in graph_module.graph.nodes}
+            )
+            self.assertTrue(
+                exir_ops.edge.aten.mm
+                not in {node.target for node in graph_module.graph.nodes}
+            )
+            # Check that there exists a call_delegate, representing the call to the
+            # delegated function
+            FileCheck().check("torch.ops.higher_order.executorch_call_delegate").run(
+                graph_module.code
+            )
+            lowered_submodules = get_lowered_submodules(graph_module)
+            self.assertEqual(len(lowered_submodules), 1)
+
+        program = exec_prog.executorch_program
+
+        # Check the program can be printed
+        print_program(program)
+
+        buff = exec_prog.buffer
+
+        executorch_module = _load_for_executorch_from_buffer(buff)
+
+        for method_name, module in {
+            "sin": sin_module,
+            "add_mul": add_mul_module,
+        }.items():
+            inputs_flattened, _ = tree_flatten(module.inputs())
+            model_outputs = executorch_module.run_method(
+                method_name, tuple(inputs_flattened)
+            )
+
+            if method_name == "sin":
+                # BackendWithCompilerDemo does a Taylor approximation of sin.
+                ref_output = 0.8333 * torch.ones(1)
+            else:
+                ref_output = module(*module.inputs())
+            self.assertTrue(
+                torch.allclose(model_outputs[0], ref_output, atol=1e-03, rtol=1e-03)
+            )
diff --git a/exir/backend/test/test_backends_lifted.py b/exir/backend/test/test_backends_lifted.py
index 3c55bebd320..be9527b8ccb 100644
--- a/exir/backend/test/test_backends_lifted.py
+++ b/exir/backend/test/test_backends_lifted.py
@@ -11,6 +11,9 @@
 import torch
 from executorch.exir import to_edge
 from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend
+from executorch.exir.backend.canonical_partitioners.all_node_partitioner import (
+    AllNodePartitioner,
+)
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.backend.partitioner import (
     DelegationSpec,
@@ -138,6 +141,18 @@ def forward(self, x):
 
         self.assertTrue(torch.allclose(new_res, expected_res))
 
+        # Test the same flow, but through the edge_program_manager.
+        edgeir_m = to_edge(export(sin_module, model_inputs, strict=True))
+        loweredir_m = edgeir_m.to_backend(
+            AllNodePartitioner(BackendWithCompilerDemo.__name__, [])
+        )
+        lowered_sin_module = get_lowered_submodules(
+            loweredir_m.exported_program().graph_module
+        )[0][1]
+
+        new_res = lowered_sin_module(*model_inputs)[0]
+
+        self.assertTrue(torch.allclose(new_res, expected_res))
 
 # TODO(tkaruturi): emitting single LoweredBackendModule
 # program = to_edge(export(graph_module)).to_exectorch()._emitter_output.program
diff --git a/exir/backend/test/test_compatibility.py b/exir/backend/test/test_compatibility.py
index 9d87aa5be0e..bcda1d36516 100644
--- a/exir/backend/test/test_compatibility.py
+++ b/exir/backend/test/test_compatibility.py
@@ -10,6 +10,9 @@ from executorch.exir import to_edge
 from executorch.exir._serialize import _serialize_pte_binary
 from executorch.exir.backend.backend_api import to_backend
+from executorch.exir.backend.canonical_partitioners.all_node_partitioner import (
+    AllNodePartitioner,
+)
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.backend.test.backend_with_compiler_demo import (
     BackendWithCompilerDemo,
 )
@@ -65,3 +68,49 @@ def forward(self, x):
             "loading method forward failed with error 0x30",
         ):
             executorch_module = _load_for_executorch_from_buffer(buff)
+
+    def test_compatibility_in_runtime_edge_program_manager(self):
+        class SinModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.sin(x)
+
+        sin_module = SinModule()
+        model_inputs = (torch.ones(1),)
+        edgeir_m = to_edge(export(sin_module, model_inputs, strict=True))
+        max_value = model_inputs[0].shape[0]
+        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
+        lowered_edge_irm = edgeir_m.to_backend(
+            AllNodePartitioner("BackendWithCompilerDemo", compile_specs)
+        )
+        exec_prog = lowered_edge_irm.to_executorch()
+
+        buff = exec_prog.buffer
+
+        # The demo backend works well
+        executorch_module = _load_for_executorch_from_buffer(buff)
+        model_inputs = torch.ones(1)
+        _ = executorch_module.forward([model_inputs])
+
+        prog = exec_prog.executorch_program
+        # Rewrite the delegate version number from 0 to 1.
+        prog.backend_delegate_data[0].data = bytes(
+            "1version:1#op:demo::aten.sin.default, numel:1, dtype:torch.float321#",
+            encoding="utf8",
+        )
+
+        # Generate the .pte file with the wrong version.
+        buff = bytes(
+            _serialize_pte_binary(
+                program=prog,
+            )
+        )
+
+        # Throw runtime error with error code 0x30, meaning delegate is incompatible.
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "loading method forward failed with error 0x30",
+        ):
+            executorch_module = _load_for_executorch_from_buffer(buff)
diff --git a/exir/backend/test/test_delegate_map_builder.py b/exir/backend/test/test_delegate_map_builder.py
index 827cb8cdebc..fcd23b110b6 100644
--- a/exir/backend/test/test_delegate_map_builder.py
+++ b/exir/backend/test/test_delegate_map_builder.py
@@ -9,12 +9,17 @@
 import torch
 
 from executorch import exir
+from executorch.exir import to_edge
 from executorch.exir.backend.backend_api import to_backend
+from executorch.exir.backend.canonical_partitioners.all_node_partitioner import (
+    AllNodePartitioner,
+)
 from executorch.exir.backend.test.backend_with_delegate_mapping_demo import (
     BackendWithDelegateMappingDemo,
 )
 
 from executorch.exir.backend.utils import DelegateMappingBuilder
+from executorch.exir.lowered_backend_module import get_lowered_submodules
 
 
 class TestDelegateMapBuilder(unittest.TestCase):
diff --git a/exir/program/TARGETS b/exir/program/TARGETS
index 33e417e7326..911c33ec692 100644
--- a/exir/program/TARGETS
+++ b/exir/program/TARGETS
@@ -31,6 +31,7 @@ python_library(
         "//executorch/exir/_serialize:lib",
         "//executorch/exir/backend:backend_api",
         "//executorch/exir/backend:partitioner",
+        "//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
        "//executorch/exir/capture:config",
         "//executorch/exir/emit:emit",
         "//executorch/exir/emit:lib",
diff --git a/exir/program/_program.py b/exir/program/_program.py
index 7a2120f9e9b..2d72b4f406f 100644
--- a/exir/program/_program.py
+++ b/exir/program/_program.py
@@ -24,7 +24,10 @@
 from executorch.exir._serialize.data_serializer import DataSerializer
 from executorch.exir._warnings import experimental
 from executorch.exir.backend.backend_api import to_backend
-from executorch.exir.backend.partitioner import Partitioner
+from executorch.exir.backend.canonical_partitioners.all_node_partitioner import (
+    AllNodePartitioner,
+)
+from executorch.exir.backend.partitioner import DelegationSpec, Partitioner
 from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig
 from executorch.exir.delegate import executorch_call_delegate, is_lowered_module
 from executorch.exir.emit import emit_program, EmitterOutput
@@ -1439,7 +1442,13 @@ def transform(
 
     @et_logger("to_backend")
     def to_backend(
-        self, partitioner: Union[Partitioner, Dict[str, Partitioner]]
+        self,
+        partitioner: Union[
+            DelegationSpec,
+            Dict[str, DelegationSpec],
+            Partitioner,
+            Dict[str, Partitioner],
+        ],
     ) -> "EdgeProgramManager":
         """
         Returns a semantically-equivalent program to the one given as input,
@@ -1447,12 +1456,15 @@
         for delegation as determined by the partitioner.
 
         Args:
-            partitioner: The partitioner can either be a Partitioner subclass instance, or a
-                dictionary mapping method names to Partitioner subclass instance. If it is a
-                Partitioner subclass, all programs in the given EdgeProgramManager
-                will be lowered using the given partitioner. If it is a
-                dictionary, only method names specified in the dictionary will be
-                lowered with the given partitioner.
+            partitioner: The partitioner can be:
+                - a Partitioner subclass instance: all programs in the EdgeProgramManager
+                  are lowered with this partitioner.
+                - a dictionary mapping method names to Partitioner subclass instances: only
+                  the method names specified in the dictionary are lowered with the given partitioner.
+                - a DelegationSpec: all programs are lowered entirely to the backend_id
+                  specified in the DelegationSpec.
+                - a dictionary mapping method names to DelegationSpecs: only the method names
+                  specified in the dictionary are lowered entirely to the corresponding backend_id.
 
         The Partitioner subclass instance is in charge with tagging portions of the input
         program for delegation. A valid partitioner must return PartitionerResult including valid
@@ -1468,13 +1480,19 @@
         if isinstance(partitioner, dict):
             for name, program in self._edge_programs.items():
                 if name in partitioner.keys():
-                    new_edge_programs[name] = to_backend(program, partitioner[name])
+                    partitioner_to_use = partitioner[name]
+                    if isinstance(partitioner_to_use, DelegationSpec):
+                        partitioner_to_use = AllNodePartitioner(*partitioner_to_use)
+                    new_edge_programs[name] = to_backend(program, partitioner_to_use)
                 else:
                     new_edge_programs[name] = program
 
         else:  # apply partitioner to every method
             for name, program in self._edge_programs.items():
-                new_edge_programs[name] = to_backend(program, partitioner)
+                partitioner_to_use = partitioner
+                if isinstance(partitioner, DelegationSpec):
+                    partitioner_to_use = AllNodePartitioner(*partitioner)
+                new_edge_programs[name] = to_backend(program, partitioner_to_use)
 
         config = EdgeCompileConfig(_check_ir_validity=False)
         return EdgeProgramManager(
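
Usage sketch (illustrative, not part of the diff): the AllNodePartitioner added above lowers an entire exported program to a single backend in one step. A minimal flow, assuming the demo backend and "max_value" compile spec used by the tests in exir/backend/test; the module and inputs are stand-ins:

    import torch
    from executorch.exir import to_edge
    from executorch.exir.backend.canonical_partitioners.all_node_partitioner import (
        AllNodePartitioner,
    )
    from executorch.exir.backend.compile_spec_schema import CompileSpec

    class SinModule(torch.nn.Module):
        def forward(self, x):
            return torch.sin(x)

    model_inputs = (torch.ones(1),)
    edge = to_edge(torch.export.export(SinModule(), model_inputs))

    # Lower every node of every method to the demo backend; the result contains
    # a single lowered submodule behind one executorch_call_delegate.
    edge = edge.to_backend(
        AllNodePartitioner(
            "BackendWithCompilerDemo",
            [CompileSpec("max_value", bytes([model_inputs[0].shape[0]]))],
        )
    )
    pte_buffer = edge.to_executorch().buffer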
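Since EdgeProgramManager.to_backend now also accepts DelegationSpecs, the same whole-program lowering can be requested without constructing the partitioner by hand. A sketch under the same assumptions, continuing from the snippet above and including the per-method dictionary form:

    from executorch.exir.backend.partitioner import DelegationSpec

    spec = DelegationSpec(
        "BackendWithCompilerDemo",
        [CompileSpec("max_value", bytes([1]))],
    )

    # Either lower every method of the manager entirely to the backend...
    edge = edge.to_backend(spec)

    # ...or lower only the listed methods; methods absent from the dictionary
    # are left untouched.
    edge = edge.to_backend({"forward": spec})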