Skip to content

exclude mutated buffer #2946

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 83 additions & 1 deletion exir/backend/test/test_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from executorch.exir.backend.test.demos.rpc.executor_backend_preprocess import (
ExecutorBackend,
)
from executorch.exir.backend.utils import get_delegates
from executorch.exir.backend.utils import get_delegates, tag_constant_data

from executorch.exir.dialects._ops import ops as exir_ops

Expand Down Expand Up @@ -523,3 +523,85 @@ def partition(
"constant data node (b_const) is tagged with (tag0) but has user (aten_sub_tensor) which has tag (None)",
str(error.exception),
)

def test_not_delegate_mutable_buffers(self) -> None:
"""
A test case to check the mutated buffer is not delegated. We'll need to add a test case
to consider when the delegate can consume the mutable buffer.
"""

class MutableStateModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("my_state", torch.zeros(1))

def forward(self, x):
y = x + self.my_state
self.my_state.add_(1)
return y

edge = exir.to_edge(
torch.export.export(
MutableStateModule(),
(torch.zeros(1),),
)
)
self.assertGreater(
len(edge.exported_program().graph_signature.buffers_to_mutate),
0,
"The test case should at leaset one mutable buffer",
)

class PartitionerTagData(Partitioner):
def __init__(self):
super().__init__()
self.delegation_spec = DelegationSpec(
ExecutorBackend.__name__,
[CompileSpec(key, value) for key, value in self.spec.items()],
)

def partition(
self, edge_exported_program: ExportedProgram
) -> PartitionResult:
partition_tags = {}
for node in edge_exported_program.graph.nodes:
if node.op == "call_function" and node.target in [
exir_ops.edge.aten.add.Tensor
]:
delegation_tag = "tag0"
node.meta["delegation_tag"] = delegation_tag
partition_tags[delegation_tag] = self.delegation_spec
tag_constant_data(edge_exported_program)
return PartitionResult(
tagged_exported_program=edge_exported_program,
partition_tags=partition_tags,
)

# Check the edge program inital buffers_to_mutate
mutate_op = "aten_add_tensor_1"
self.assertEqual(
edge.exported_program().graph_signature.buffers_to_mutate[mutate_op],
"my_state",
)
edge = edge.to_backend(PartitionerTagData())
# After to_backend, add is delegated and is no longer in buffers_to_mutate.
self.assertNotIn(
mutate_op,
edge.exported_program().graph_signature.buffers_to_mutate,
)

mutate_op = "getitem_1"
# Ensure the mutated buffer is not delegated, and the new mutate node is getitem (from call_delegate)
self.assertEqual(
edge.exported_program().graph_signature.buffers_to_mutate[mutate_op],
"my_state",
)
# Check the copy_ node is inserted
edge = edge.to_executorch()
copy_node = [
node
for node in edge.exported_program().graph.nodes
if node.op == "call_function"
and node.target == torch.ops.aten.copy_.default
]
self.assertEqual(len(copy_node), 1)
43 changes: 29 additions & 14 deletions exir/backend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,27 +508,42 @@ def tag_constant_data(edge_program: ExportedProgram) -> None:
subgraph. Throw error when const/param/buffers is used across different partitions. That is the
underlying data will be owned by multiple delegates.
"""
mutated_buffer = set()
for node in edge_program.graph.nodes:
if node.op == "placeholder" and (
is_param(edge_program, node)
or is_buffer(edge_program, node)
or is_lifted_tensor_constant(edge_program, node)
):
for node_user in node.users:
if node_user.name in edge_program.graph_signature.buffers_to_mutate:
logging.info(
"The buffer node is a mutated buffer node, which is not constant."
)
mutated_buffer.add(node)

for node in edge_program.graph.nodes:
# go through const/param/buffer nodes, if all users of const/param/buffer nodes are partitioned then partition
if node.op == "placeholder" and (
is_param(edge_program, node)
or is_buffer(edge_program, node)
or is_lifted_tensor_constant(edge_program, node)
):
user_tags = set()
for user in node.users:
user_tag = user.meta.get("delegation_tag", None)
if user_tag is not None:
user_tags.add(user_tag)
if len(user_tags) > 1:
logging.info(
f"The data node is used across multiple partitions, including {user_tags}. "
"If the data is too large and it's not preferred to copy, please tag the "
"constant node like node.['no_copy'] = True and they won't be copied."
)
# tag the data node with the same tag as the last user
if len(user_tags) > 0:
node.meta["delegation_tag"] = user_tags.pop()
if node not in mutated_buffer:
user_tags = set()
for user in node.users:
user_tag = user.meta.get("delegation_tag", None)
if user_tag is not None:
user_tags.add(user_tag)
if len(user_tags) > 1:
logging.info(
f"The data node is used across multiple partitions, including {user_tags}. "
"If the data is too large and it's not preferred to copy, please tag the "
"constant node like node.['no_copy'] = True and they won't be copied."
)
# tag the data node with the same tag as the last user
if len(user_tags) > 0:
node.meta["delegation_tag"] = user_tags.pop()


# TODO - style: use templated types
Expand Down
Loading