From 00db80777876d81d476acb14781afdb1e07a1cf0 Mon Sep 17 00:00:00 2001 From: Chen Lai Date: Mon, 8 Apr 2024 13:55:49 -0700 Subject: [PATCH] exclude mutated buffer (#2876) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/2876 Fixing the constant tagging for mutable buffers. A buffer shouldn't be tagged if it's going to be mutated by the delegate. This is more common in hardware backends. Will follow up and test having the delegate consume the mutation. Reviewed By: mcr229, angelayi Differential Revision: D55812844 fbshipit-source-id: e0be4c2dc295141d673cccb1aeecee45894b1e70 (cherry picked from commit 599cfde9f542d43587ea0b1de03bffc73f7488eb) --- exir/backend/test/test_partitioner.py | 84 ++++++++++++++++++++++++++- exir/backend/utils.py | 43 +++++++++----- 2 files changed, 112 insertions(+), 15 deletions(-) diff --git a/exir/backend/test/test_partitioner.py b/exir/backend/test/test_partitioner.py index 74974d16231..d492c291f34 100644 --- a/exir/backend/test/test_partitioner.py +++ b/exir/backend/test/test_partitioner.py @@ -26,7 +26,7 @@ from executorch.exir.backend.test.demos.rpc.executor_backend_preprocess import ( ExecutorBackend, ) -from executorch.exir.backend.utils import get_delegates +from executorch.exir.backend.utils import get_delegates, tag_constant_data from executorch.exir.dialects._ops import ops as exir_ops @@ -523,3 +523,85 @@ def partition( "constant data node (b_const) is tagged with (tag0) but has user (aten_sub_tensor) which has tag (None)", str(error.exception), ) + + def test_not_delegate_mutable_buffers(self) -> None: + """ + A test case to check the mutated buffer is not delegated. We'll need to add a test case + to consider when the delegate can consume the mutable buffer.
+ """ + + class MutableStateModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("my_state", torch.zeros(1)) + + def forward(self, x): + y = x + self.my_state + self.my_state.add_(1) + return y + + edge = exir.to_edge( + torch.export.export( + MutableStateModule(), + (torch.zeros(1),), + ) + ) + self.assertGreater( + len(edge.exported_program().graph_signature.buffers_to_mutate), + 0, + "The test case should at leaset one mutable buffer", + ) + + class PartitionerTagData(Partitioner): + def __init__(self): + super().__init__() + self.delegation_spec = DelegationSpec( + ExecutorBackend.__name__, + [CompileSpec(key, value) for key, value in self.spec.items()], + ) + + def partition( + self, edge_exported_program: ExportedProgram + ) -> PartitionResult: + partition_tags = {} + for node in edge_exported_program.graph.nodes: + if node.op == "call_function" and node.target in [ + exir_ops.edge.aten.add.Tensor + ]: + delegation_tag = "tag0" + node.meta["delegation_tag"] = delegation_tag + partition_tags[delegation_tag] = self.delegation_spec + tag_constant_data(edge_exported_program) + return PartitionResult( + tagged_exported_program=edge_exported_program, + partition_tags=partition_tags, + ) + + # Check the edge program inital buffers_to_mutate + mutate_op = "aten_add_tensor_1" + self.assertEqual( + edge.exported_program().graph_signature.buffers_to_mutate[mutate_op], + "my_state", + ) + edge = edge.to_backend(PartitionerTagData()) + # After to_backend, add is delegated and is no longer in buffers_to_mutate. 
+ self.assertNotIn( + mutate_op, + edge.exported_program().graph_signature.buffers_to_mutate, + ) + + mutate_op = "getitem_1" + # Ensure the mutated buffer is not delegated, and the new mutate node is getitem (from call_delegate) + self.assertEqual( + edge.exported_program().graph_signature.buffers_to_mutate[mutate_op], + "my_state", + ) + # Check the copy_ node is inserted + edge = edge.to_executorch() + copy_node = [ + node + for node in edge.exported_program().graph.nodes + if node.op == "call_function" + and node.target == torch.ops.aten.copy_.default + ] + self.assertEqual(len(copy_node), 1) diff --git a/exir/backend/utils.py b/exir/backend/utils.py index f4c1c28f8bd..b299ba4be8a 100644 --- a/exir/backend/utils.py +++ b/exir/backend/utils.py @@ -508,6 +508,20 @@ def tag_constant_data(edge_program: ExportedProgram) -> None: subgraph. Throw error when const/param/buffers is used across different partitions. That is the underlying data will be owned by multiple delegates. """ + mutated_buffer = set() + for node in edge_program.graph.nodes: + if node.op == "placeholder" and ( + is_param(edge_program, node) + or is_buffer(edge_program, node) + or is_lifted_tensor_constant(edge_program, node) + ): + for node_user in node.users: + if node_user.name in edge_program.graph_signature.buffers_to_mutate: + logging.info( + "The buffer node is a mutated buffer node, which is not constant." 
+ ) + mutated_buffer.add(node) + for node in edge_program.graph.nodes: # go through const/param/buffer nodes, if all users of const/param/buffer nodes are partitioned then partition if node.op == "placeholder" and ( @@ -515,20 +529,21 @@ def tag_constant_data(edge_program: ExportedProgram) -> None: or is_buffer(edge_program, node) or is_lifted_tensor_constant(edge_program, node) ): - user_tags = set() - for user in node.users: - user_tag = user.meta.get("delegation_tag", None) - if user_tag is not None: - user_tags.add(user_tag) - if len(user_tags) > 1: - logging.info( - f"The data node is used across multiple partitions, including {user_tags}. " - "If the data is too large and it's not preferred to copy, please tag the " - "constant node like node.['no_copy'] = True and they won't be copied." - ) - # tag the data node with the same tag as the last user - if len(user_tags) > 0: - node.meta["delegation_tag"] = user_tags.pop() + if node not in mutated_buffer: + user_tags = set() + for user in node.users: + user_tag = user.meta.get("delegation_tag", None) + if user_tag is not None: + user_tags.add(user_tag) + if len(user_tags) > 1: + logging.info( + f"The data node is used across multiple partitions, including {user_tags}. " + "If the data is too large and it's not preferred to copy, please tag the " + "constant node like node.['no_copy'] = True and they won't be copied." + ) + # tag the data node with the same tag as the last user + if len(user_tags) > 0: + node.meta["delegation_tag"] = user_tags.pop() # TODO - style: use templated types