Commit 457f1a3

angelayi authored and facebook-github-bot committed
Allow delegate to consume buffer mutations (#4830)
Summary:
Pull Request resolved: #4830

Fixing #4209

Edge Program:

```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, b_b: "f32[3, 3]", x: "f32[3, 3]"):
            # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:631 in forward, code: self.b.add_(x)
            aten_add_tensor: "f32[3, 3]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(b_b, x);  b_b = None

            # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:632 in forward, code: return x + self.b
            aten_add_tensor_1: "f32[3, 3]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(x, aten_add_tensor);  x = None
            return (aten_add_tensor, aten_add_tensor_1)

Graph signature: ExportGraphSignature(
    input_specs=[
        InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_b'), target='b', persistent=True),
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)
    ],
    output_specs=[
        OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='aten_add_tensor'), target='b'),
        OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='aten_add_tensor_1'), target=None)
    ]
)
```

Partitioned / lowered Exported Program (buffer mutation gets removed):

```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 3]"):
            # No stacktrace found for following nodes
            lowered_module_0 = self.lowered_module_0
            executorch_call_delegate = torch.ops.higher_order.executorch_call_delegate(lowered_module_0, x);  lowered_module_0 = x = None

            # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:632 in forward, code: return x + self.b
            getitem_1: "f32[3, 3]" = executorch_call_delegate[0];  executorch_call_delegate = None
            return (getitem_1,)

Graph signature: ExportGraphSignature(
    input_specs=[
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)
    ],
    output_specs=[
        OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem_1'), target=None)
    ]
)
```

Delegate (consumes the buffer mutation):

```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, b_b: "f32[3, 3]", x: "f32[3, 3]"):
            # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:631 in forward, code: self.b.add_(x)
            aten_add_tensor: "f32[3, 3]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(b_b, x);  b_b = None

            # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:632 in forward, code: return x + self.b
            aten_add_tensor_1: "f32[3, 3]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(x, aten_add_tensor);  x = None
            return (aten_add_tensor, aten_add_tensor_1)

Graph signature: ExportGraphSignature(
    input_specs=[
        InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_b'), target='b', persistent=True),
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)
    ],
    output_specs=[
        OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='aten_add_tensor'), target='b'),
        OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='aten_add_tensor_1'), target=None)
    ]
)
```

Differential Revision: D60838243
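For orientation, here is a minimal sketch of the flow that produces the programs above, mirroring test_buffer_mutation1 from this diff (the demo partitioner and exir APIs are the ones used in the tests below; treat this as illustrative, not canonical):

```python
import torch

from executorch import exir
from executorch.exir.backend.test.op_partitioner_demo import AddAttributePartitionerDemo


class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("b", torch.ones(3, 3))

    def forward(self, x):
        self.b.add_(x)  # in-place buffer mutation
        return x + self.b


# Export and convert to the edge dialect (the "Edge Program" above).
edge = exir.to_edge(torch.export.export(TestModule(), (torch.ones(3, 3),)))

# Lower to the demo backend; the delegate consumes both the buffer
# and its mutation (the "Partitioned / lowered" program above).
lowered = edge.to_backend(AddAttributePartitionerDemo())

# No buffer mutation remains in the top-level graph signature.
print(lowered.exported_program().graph_signature.buffers_to_mutate)  # {}
```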
1 parent ff46dd5 commit 457f1a3

File tree

4 files changed: +275 −7 lines changed


exir/backend/test/TARGETS (+1)

```diff
@@ -290,6 +290,7 @@ python_unittest(
         "//executorch/exir/backend/test/demos/rpc:executor_backend_register",
     ],
     deps = [
+        ":op_partitioner_demo",
         "//caffe2:torch",
         "//executorch/exir:lib",
         "//executorch/exir/backend:backend_details",
```

exir/backend/test/op_partitioner_demo.py (+50)

```diff
@@ -21,6 +21,9 @@
 from executorch.exir.backend.test.backend_with_compiler_demo import (
     BackendWithCompilerDemo,
 )
+from executorch.exir.backend.test.demos.rpc.executor_backend_preprocess import (
+    ExecutorBackend,
+)
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.graph_module import get_control_flow_submodules
 from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
@@ -29,6 +32,11 @@
 from torch.fx.passes.operator_support import any_chain, OperatorSupportBase
 
 
+class AllOperatorSupport(OperatorSupportBase):
+    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
+        return node.op == "call_function"
+
+
 class AddOperatorSupport(OperatorSupportBase):
     def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
         return node.op == "call_function" and node.target in [
@@ -126,6 +134,48 @@ def partition(self, edge_exported_program: ExportedProgram) -> PartitionResult:
     )
 
 
+@final
+class AllNodesPartitionerDemo(Partitioner):
+    """
+    Partitions all nodes.
+    """
+
+    def __init__(self) -> None:
+        self.op_support = AllOperatorSupport()
+        self.delegation_spec = DelegationSpec(ExecutorBackend.__name__, [])
+
+    def partition(self, edge_exported_program: ExportedProgram) -> PartitionResult:
+        partition_tags = {}
+        partition_list = generate_pattern_op_partitions(
+            edge_exported_program.graph_module, op_support=self.op_support
+        )
+        for partition in partition_list:
+            for node in partition.nodes:
+                delegation_tag = f"tag{partition.id}"
+                partition_tags[delegation_tag] = self.delegation_spec
+
+                # Tag the node with its partition's delegation tag
+                node.meta["delegation_tag"] = delegation_tag
+
+                for arg_node in node.args:
+                    if not isinstance(arg_node, torch.fx.Node):
+                        continue
+
+                    is_get_attr = arg_node.op == "get_attr"
+                    is_param_buffer = arg_node.op == "placeholder" and (
+                        is_param(edge_exported_program, arg_node)
+                        or is_buffer(edge_exported_program, arg_node)
+                        or is_lifted_tensor_constant(edge_exported_program, arg_node)
+                    )
+                    if is_get_attr or is_param_buffer:
+                        # Also tag constant/param/buffer arguments so the
+                        # delegate consumes them along with the node.
+                        arg_node.meta["delegation_tag"] = delegation_tag
+
+        return PartitionResult(
+            tagged_exported_program=edge_exported_program, partition_tags=partition_tags
+        )
+
+
 ops_not_to_decompose = [
     torch.ops.aten.linear.default,
     torch.ops.aten.scaled_dot_product_attention.default,
```
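As a quick illustration of what this partitioner produces, a small sketch (assuming the demo classes above; the trivial module and prints are hypothetical) that runs the partitioner directly and inspects the delegation tags it writes into node.meta:

```python
import torch

from executorch import exir
from executorch.exir.backend.test.op_partitioner_demo import AllNodesPartitionerDemo


class Add(torch.nn.Module):
    def forward(self, x, y):
        return x + y


edge = exir.to_edge(torch.export.export(Add(), (torch.ones(2), torch.ones(2))))

# Run the partitioner directly and inspect its tagging.
result = AllNodesPartitionerDemo().partition(edge.exported_program())
for node in result.tagged_exported_program.graph.nodes:
    print(node.op, node.name, node.meta.get("delegation_tag"))
# Every call_function node (and any tagged constant inputs) carries a
# "tag<N>" naming the partition that will be lowered to ExecutorBackend.
```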

exir/backend/test/test_partitioner.py (+109)

```diff
@@ -26,6 +26,7 @@
 from executorch.exir.backend.test.demos.rpc.executor_backend_preprocess import (
     ExecutorBackend,
 )
+from executorch.exir.backend.test.op_partitioner_demo import AddAttributePartitionerDemo, AllNodesPartitionerDemo
 from executorch.exir.backend.utils import get_delegates, tag_constant_data
 
 from executorch.exir.dialects._ops import ops as exir_ops
@@ -619,3 +620,111 @@ def partition(
             and node.target == torch.ops.aten.copy_.default
         ]
         self.assertEqual(len(copy_node), 1)
+
+    def test_buffer_mutation1(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer("b", torch.ones(3, 3))
+
+            def forward(self, x):
+                self.b.add_(x)
+                return x + self.b
+
+        model_inputs = (torch.ones(3, 3),)
+        orig_res = TestModule()(*model_inputs)
+        edge_program = exir.to_edge(torch.export.export(TestModule(), model_inputs))
+        lowered = edge_program.to_backend(AddAttributePartitionerDemo())
+
+        self.assertTrue(
+            torch.allclose(lowered.exported_program().module()(*model_inputs), orig_res)
+        )
+
+        self.assertEqual(
+            len(lowered.exported_program().graph_signature.buffers_to_mutate),
+            0,
+        )
+        lowered_module_nodes = get_delegates(lowered.exported_program().graph)
+        self.assertEqual(len(lowered_module_nodes), 1)
+        lowered_module_node = lowered_module_nodes[0]
+
+        # get call delegate node
+        call_delegate_node = list(lowered_module_node.users.keys())[0]
+        self.assertEqual(len(call_delegate_node.args), 2)
+
+        lower_module = getattr(
+            lowered.exported_program().graph_module, lowered_module_node.name
+        )
+        delegated_ep = lower_module.original_module
+
+        self.assertEqual(len(delegated_ep.state_dict), 1)
+        self.assertEqual(len(delegated_ep.graph_signature.buffers_to_mutate), 1)
+        self.assertEqual(len(delegated_ep.graph_signature.buffers), 1)
+
+    def test_buffer_mutation_llama_repro(self):
+        SHAPE = (2, 3)
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer("cache", torch.zeros(SHAPE, dtype=torch.float32))
+
+            def forward(self, q, k_val, input_pos):
+                q_T = q.transpose(0, 1)
+                k = torch.ops.aten.index_put_(self.cache, [input_pos, None], k_val)
+                attn = k.mm(q_T)
+                return attn
+
+        q = torch.rand(1, 3)
+        k = torch.rand(1, 3)
+        example_inputs = (q, k, torch.tensor([1, 1]))
+
+        model = Model()
+        model.eval()
+
+        exir_program_aten = torch.export.export(model, example_inputs)
+        exir_program_aten.module()(*example_inputs)
+        edge_program_manager = exir.to_edge(exir_program_aten)
+        lowered = edge_program_manager.to_backend(AllNodesPartitionerDemo())
+
+        self.assertEqual(
+            len(lowered.exported_program().graph_signature.buffers_to_mutate),
+            0,
+        )
+        lowered_module_nodes = get_delegates(lowered.exported_program().graph)
+        self.assertEqual(len(lowered_module_nodes), 1)
+        lowered_module_node = lowered_module_nodes[0]
+
+        # get call delegate node
+        call_delegate_node = list(lowered_module_node.users.keys())[0]
+        self.assertEqual(len(call_delegate_node.args), 4)
+
+        lower_module = getattr(
+            lowered.exported_program().graph_module, lowered_module_node.name
+        )
+        delegated_ep = lower_module.original_module
+
+        self.assertEqual(len(delegated_ep.state_dict), 1)
+        self.assertEqual(len(delegated_ep.graph_signature.buffers_to_mutate), 1)
+        self.assertEqual(len(delegated_ep.graph_signature.buffers), 1)
+
+    def test_buffer_mutation_unsupported(self):
+        SHAPE = (2, 3)
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer("state_1", torch.zeros(SHAPE, dtype=torch.float32))
+
+            def forward(self, x):
+                add = self.state_1.add_(x)
+                return add
+
+        model = Model()
+        model.eval()
+
+        example_inputs = (torch.randn(SHAPE),)
+        exir_program_aten = torch.export.export(model, example_inputs)
+        edge_program_manager = exir.to_edge(exir_program_aten)
+        with self.assertRaises(AssertionError):
+            edge_program_manager.to_backend(AddAttributePartitionerDemo())
```

exir/lowered_backend_module.py (+115 −7)

```diff
@@ -8,6 +8,7 @@
 
 import copy
 import operator
+from collections import defaultdict
 from typing import Any, Dict, List, Optional, Set, Tuple, Union
 
 import torch
@@ -488,8 +489,12 @@ def _get_new_signature(  # noqa: C901
         else {}
     )
 
+    toplevel_output_node_to_sig: Dict[str, List[OutputSpec]] = defaultdict(list)
+    if not is_submodule:
+        for output_spec in old_signature.output_specs:
+            toplevel_output_node_to_sig[output_spec.arg.name].append(output_spec)
+
     for node in gm.graph.nodes:
-        is_tagged = tag is None or node.meta.get("delegation_tag", None) == tag
         if node.op == "placeholder":
 
             if node.name not in input_node_to_sig:
@@ -507,7 +512,7 @@ def _get_new_signature(  # noqa: C901
             if not isinstance(orig_input_spec.arg, TensorArgument):
                 input_specs.append(orig_input_spec)
 
-            elif is_tagged:
+            elif node.meta.get("delegation_tag", None) == tag:
                 input_specs.append(orig_input_spec)
 
                 if orig_input_spec.kind == InputKind.USER_INPUT:
@@ -551,11 +556,67 @@ def _get_new_signature(  # noqa: C901
             )
 
         if node.op == "output":
-            output_nodes = pytree.tree_leaves((node.args, node.kwargs))
-
-            for output_node in output_nodes:
+            buffer_mutation_idxs: Dict[int, List[OutputSpec]] = defaultdict(list)
+            for user in call_module_node.users.keys():
+                if user.name in toplevel_output_node_to_sig:
+                    assert (
+                        user.op == "call_function" and user.target == operator.getitem
+                    ), f"Invalid user {user}, node.op is {user.op} and node.target is {user.target}"
+                    getitem_idx = user.args[1]
+                    assert isinstance(
+                        getitem_idx, int
+                    ), f"Invalid getitem type: {type(getitem_idx)}"
+                    buffer_mutation_idxs[getitem_idx].extend(
+                        toplevel_output_node_to_sig[user.name]
+                    )
 
-                if not isinstance(output_node, torch.fx.Node):
+            for i, output_node in enumerate(node.args[0]):
+                if i in buffer_mutation_idxs:
+                    assert isinstance(output_node, torch.fx.Node)
+                    orig_output_specs = buffer_mutation_idxs[i]
+
+                    for orig_output_spec in orig_output_specs:
+                        if (
+                            orig_output_spec.kind == OutputKind.BUFFER_MUTATION
+                            and orig_output_spec.target in new_state_dict
+                        ):
+                            # If the delegate consumes the buffer, then it
+                            # should also consume the buffer mutation (the
+                            # output spec becomes a BUFFER_MUTATION).
+                            # Otherwise the delegate just returns the result
+                            # of the mutation as a USER_OUTPUT.
+                            assert len(orig_output_specs) == 1, (
+                                f"Buffer {orig_output_spec.target} was tagged to be "
+                                "consumed by the delegate, and was found to also contain "
+                                "a buffer mutation. However this buffer mutation node "
+                                "was found to also be used as other types of outputs "
+                                "which is currently not supported. Please file an "
+                                "issue on Github. \n\n"
+                                f"The toplevel program: {original_program}\n"
+                            )
+                            output_specs.append(
+                                OutputSpec(
+                                    kind=OutputKind.BUFFER_MUTATION,
+                                    arg=TensorArgument(name=output_node.name),
+                                    target=orig_output_spec.target,
+                                )
+                            )
+                            output_specs_to_delete[orig_output_spec.arg.name] = (
+                                orig_output_spec
+                            )
+                        else:
+                            output_specs.append(
+                                OutputSpec(
+                                    kind=OutputKind.USER_OUTPUT,
+                                    arg=TensorArgument(name=output_node.name),
+                                    target=None,
+                                )
+                            )
+
+                elif not isinstance(output_node, torch.fx.Node):
                     output_specs.append(
                         OutputSpec(
                             kind=OutputKind.USER_OUTPUT,
```
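The new output handling above hinges on one mapping: the top-level signature is keyed by output node name, while the delegate's outputs are addressed by getitem index on the call_delegate node. A standalone sketch in plain Python (hypothetical data and tuples, not the ExecuTorch API) of how the specs are routed to delegate output indices:

```python
from collections import defaultdict

# Hypothetical stand-ins for the FX structures used above.
# Top-level graph signature, keyed by output node name.
toplevel_output_node_to_sig = {
    "getitem":   [("BUFFER_MUTATION", "b")],  # mutated buffer, delegate output 0
    "getitem_1": [("USER_OUTPUT", None)],     # user output, delegate output 1
}
# Users of the call_delegate node: (getitem node name, getitem index).
call_delegate_users = [("getitem", 0), ("getitem_1", 1)]

# Mirror of the buffer_mutation_idxs construction: route each top-level
# output spec to the delegate output index its getitem reads.
buffer_mutation_idxs = defaultdict(list)
for name, idx in call_delegate_users:
    if name in toplevel_output_node_to_sig:
        buffer_mutation_idxs[idx].extend(toplevel_output_node_to_sig[name])

# Delegate output 0 becomes a BUFFER_MUTATION spec (the delegate consumes
# the mutation, provided it also owns the buffer); output 1 stays USER_OUTPUT.
print(dict(buffer_mutation_idxs))
# {0: [('BUFFER_MUTATION', 'b')], 1: [('USER_OUTPUT', None)]}
```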
```diff
@@ -774,7 +835,7 @@ def get_lowered_backend_modules(
     return lowered_programs
 
 
-def _unsafe_adjust_original_program(
+def _unsafe_adjust_original_program(  # noqa: C901
     original_program: ExportedProgram,
     call_delegate_node: torch.fx.Node,
     input_specs_to_delete: Dict[str, InputSpec],
@@ -830,3 +891,50 @@ def _unsafe_adjust_original_program(
             del original_program._constants[input_spec.target]
         else:
             raise RuntimeError(f"Invalid input spec {input_spec} received")
+
+    # Delete buffer mutations from the output which were consumed by the delegate
+    toplevel_output_node = None
+    for node in reversed(original_program.graph.nodes):
+        if node.op == "output":
+            toplevel_output_node = node
+            break
+
+    assert toplevel_output_node is not None
+    assert (
+        len(toplevel_output_node.args) == 1
+    ), f"Invalid output node: {toplevel_output_node} with args {toplevel_output_node.args}"
+
+    new_output_args = [
+        arg
+        for arg in toplevel_output_node.args[0]
+        if not isinstance(arg, torch.fx.Node) or arg.name not in output_specs_to_delete
+    ]
+    toplevel_output_node.args = (tuple(new_output_args),)
+
+    # Delete the buffer mutation getitem nodes
+    getitem_idxs: List[int] = []
+    user_nodes = list(call_delegate_node.users.keys())
+    for user in user_nodes:
+        if user.name in output_specs_to_delete:
+            assert (
+                user.op == "call_function" and user.target == operator.getitem
+            ), f"Invalid user {user}, node.op is {user.op} and node.target is {user.target}"
+            user_idx = user.args[1]
+            assert isinstance(user_idx, int), f"Invalid getitem type: {type(user_idx)}"
+            getitem_idxs.append(user_idx)
+            original_program.graph.erase_node(user)
+
+    getitem_idxs.sort(reverse=True)
+
+    # Adjust all the getitem indices after the deleted getitems
+    user_nodes = list(call_delegate_node.users.keys())
+    for user in user_nodes:
+        assert user.op == "call_function" and user.target == operator.getitem
+        user_idx = user.args[1]
+        assert isinstance(user_idx, int)
+        for i, idx in enumerate(getitem_idxs):
+            if user_idx > idx:
+                user.args = (user.args[0], user_idx - (len(getitem_idxs) - i))
+                break
+
+    original_program._validate()
```
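The index fix-up at the end is the subtle part: once the buffer-mutation getitems are erased, every surviving getitem index must shift down past the holes. A standalone sketch of the same arithmetic (plain Python, hypothetical indices):

```python
# Suppose the delegate originally returned 4 outputs, and indices 0 and 2
# were buffer mutations consumed by the delegate, so their getitem nodes
# were erased from the top-level graph.
deleted_idxs = sorted([0, 2], reverse=True)  # mirrors getitem_idxs.sort(reverse=True)
remaining_idxs = [1, 3]

adjusted = []
for user_idx in remaining_idxs:
    for i, idx in enumerate(deleted_idxs):
        if user_idx > idx:
            # Number of deleted indices below user_idx is len(deleted_idxs) - i,
            # since deleted_idxs is sorted in descending order.
            user_idx -= len(deleted_idxs) - i
            break
    adjusted.append(user_idx)

print(adjusted)  # [0, 1]: old getitem index 1 -> 0, old index 3 -> 1
```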
