Skip to content

Commit 599cfde

Browse files
cccclaifacebook-github-bot
authored andcommitted
exclude mutated buffer (#2876)
Summary: Pull Request resolved: #2876 Fixing the tag constant for mutable buffer. The buffer shouldn't be tagged if it's going to be mutated by the delegated. It's more common in hardware backends Will follow up and test having delegate consume mutation Reviewed By: mcr229, angelayi Differential Revision: D55812844 fbshipit-source-id: e0be4c2dc295141d673cccb1aeecee45894b1e70
1 parent dc7e4d5 commit 599cfde

File tree

2 files changed

+112
-15
lines changed

2 files changed

+112
-15
lines changed

exir/backend/test/test_partitioner.py

+83-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from executorch.exir.backend.test.demos.rpc.executor_backend_preprocess import (
2727
ExecutorBackend,
2828
)
29-
from executorch.exir.backend.utils import get_delegates
29+
from executorch.exir.backend.utils import get_delegates, tag_constant_data
3030

3131
from executorch.exir.dialects._ops import ops as exir_ops
3232

@@ -523,3 +523,85 @@ def partition(
523523
"constant data node (b_const) is tagged with (tag0) but has user (aten_sub_tensor) which has tag (None)",
524524
str(error.exception),
525525
)
526+
527+
def test_not_delegate_mutable_buffers(self) -> None:
528+
"""
529+
A test case to check the mutated buffer is not delegated. We'll need to add a test case
530+
to consider when the delegate can consume the mutable buffer.
531+
"""
532+
533+
class MutableStateModule(torch.nn.Module):
534+
def __init__(self):
535+
super().__init__()
536+
self.register_buffer("my_state", torch.zeros(1))
537+
538+
def forward(self, x):
539+
y = x + self.my_state
540+
self.my_state.add_(1)
541+
return y
542+
543+
edge = exir.to_edge(
544+
torch.export.export(
545+
MutableStateModule(),
546+
(torch.zeros(1),),
547+
)
548+
)
549+
self.assertGreater(
550+
len(edge.exported_program().graph_signature.buffers_to_mutate),
551+
0,
552+
"The test case should at leaset one mutable buffer",
553+
)
554+
555+
class PartitionerTagData(Partitioner):
556+
def __init__(self):
557+
super().__init__()
558+
self.delegation_spec = DelegationSpec(
559+
ExecutorBackend.__name__,
560+
[CompileSpec(key, value) for key, value in self.spec.items()],
561+
)
562+
563+
def partition(
564+
self, edge_exported_program: ExportedProgram
565+
) -> PartitionResult:
566+
partition_tags = {}
567+
for node in edge_exported_program.graph.nodes:
568+
if node.op == "call_function" and node.target in [
569+
exir_ops.edge.aten.add.Tensor
570+
]:
571+
delegation_tag = "tag0"
572+
node.meta["delegation_tag"] = delegation_tag
573+
partition_tags[delegation_tag] = self.delegation_spec
574+
tag_constant_data(edge_exported_program)
575+
return PartitionResult(
576+
tagged_exported_program=edge_exported_program,
577+
partition_tags=partition_tags,
578+
)
579+
580+
# Check the edge program inital buffers_to_mutate
581+
mutate_op = "aten_add_tensor_1"
582+
self.assertEqual(
583+
edge.exported_program().graph_signature.buffers_to_mutate[mutate_op],
584+
"my_state",
585+
)
586+
edge = edge.to_backend(PartitionerTagData())
587+
# After to_backend, add is delegated and is no longer in buffers_to_mutate.
588+
self.assertNotIn(
589+
mutate_op,
590+
edge.exported_program().graph_signature.buffers_to_mutate,
591+
)
592+
593+
mutate_op = "getitem_1"
594+
# Ensure the mutated buffer is not delegated, and the new mutate node is getitem (from call_delegate)
595+
self.assertEqual(
596+
edge.exported_program().graph_signature.buffers_to_mutate[mutate_op],
597+
"my_state",
598+
)
599+
# Check the copy_ node is inserted
600+
edge = edge.to_executorch()
601+
copy_node = [
602+
node
603+
for node in edge.exported_program().graph.nodes
604+
if node.op == "call_function"
605+
and node.target == torch.ops.aten.copy_.default
606+
]
607+
self.assertEqual(len(copy_node), 1)

exir/backend/utils.py

+29-14
Original file line numberDiff line numberDiff line change
@@ -508,27 +508,42 @@ def tag_constant_data(edge_program: ExportedProgram) -> None:
508508
subgraph. Throw error when const/param/buffers is used across different partitions. That is the
509509
underlying data will be owned by multiple delegates.
510510
"""
511+
mutated_buffer = set()
512+
for node in edge_program.graph.nodes:
513+
if node.op == "placeholder" and (
514+
is_param(edge_program, node)
515+
or is_buffer(edge_program, node)
516+
or is_lifted_tensor_constant(edge_program, node)
517+
):
518+
for node_user in node.users:
519+
if node_user.name in edge_program.graph_signature.buffers_to_mutate:
520+
logging.info(
521+
"The buffer node is a mutated buffer node, which is not constant."
522+
)
523+
mutated_buffer.add(node)
524+
511525
for node in edge_program.graph.nodes:
512526
# go through const/param/buffer nodes, if all users of const/param/buffer nodes are partitioned then partition
513527
if node.op == "placeholder" and (
514528
is_param(edge_program, node)
515529
or is_buffer(edge_program, node)
516530
or is_lifted_tensor_constant(edge_program, node)
517531
):
518-
user_tags = set()
519-
for user in node.users:
520-
user_tag = user.meta.get("delegation_tag", None)
521-
if user_tag is not None:
522-
user_tags.add(user_tag)
523-
if len(user_tags) > 1:
524-
logging.info(
525-
f"The data node is used across multiple partitions, including {user_tags}. "
526-
"If the data is too large and it's not preferred to copy, please tag the "
527-
"constant node like node.['no_copy'] = True and they won't be copied."
528-
)
529-
# tag the data node with the same tag as the last user
530-
if len(user_tags) > 0:
531-
node.meta["delegation_tag"] = user_tags.pop()
532+
if node not in mutated_buffer:
533+
user_tags = set()
534+
for user in node.users:
535+
user_tag = user.meta.get("delegation_tag", None)
536+
if user_tag is not None:
537+
user_tags.add(user_tag)
538+
if len(user_tags) > 1:
539+
logging.info(
540+
f"The data node is used across multiple partitions, including {user_tags}. "
541+
"If the data is too large and it's not preferred to copy, please tag the "
542+
"constant node like node.['no_copy'] = True and they won't be copied."
543+
)
544+
# tag the data node with the same tag as the last user
545+
if len(user_tags) > 0:
546+
node.meta["delegation_tag"] = user_tags.pop()
532547

533548

534549
# TODO - style: use templated types

0 commit comments

Comments
 (0)