|
26 | 26 | from executorch.exir.backend.test.demos.rpc.executor_backend_preprocess import (
|
27 | 27 | ExecutorBackend,
|
28 | 28 | )
|
29 |
| -from executorch.exir.backend.utils import get_delegates |
| 29 | +from executorch.exir.backend.utils import get_delegates, tag_constant_data |
30 | 30 |
|
31 | 31 | from executorch.exir.dialects._ops import ops as exir_ops
|
32 | 32 |
|
@@ -523,3 +523,85 @@ def partition(
|
523 | 523 | "constant data node (b_const) is tagged with (tag0) but has user (aten_sub_tensor) which has tag (None)",
|
524 | 524 | str(error.exception),
|
525 | 525 | )
|
| 526 | + |
| 527 | + def test_not_delegate_mutable_buffers(self) -> None: |
| 528 | + """ |
| 529 | + A test case to check the mutated buffer is not delegated. We'll need to add a test case |
| 530 | + to consider when the delegate can consume the mutable buffer. |
| 531 | + """ |
| 532 | + |
| 533 | + class MutableStateModule(torch.nn.Module): |
| 534 | + def __init__(self): |
| 535 | + super().__init__() |
| 536 | + self.register_buffer("my_state", torch.zeros(1)) |
| 537 | + |
| 538 | + def forward(self, x): |
| 539 | + y = x + self.my_state |
| 540 | + self.my_state.add_(1) |
| 541 | + return y |
| 542 | + |
| 543 | + edge = exir.to_edge( |
| 544 | + torch.export.export( |
| 545 | + MutableStateModule(), |
| 546 | + (torch.zeros(1),), |
| 547 | + ) |
| 548 | + ) |
| 549 | + self.assertGreater( |
| 550 | + len(edge.exported_program().graph_signature.buffers_to_mutate), |
| 551 | + 0, |
| 552 | + "The test case should at leaset one mutable buffer", |
| 553 | + ) |
| 554 | + |
| 555 | + class PartitionerTagData(Partitioner): |
| 556 | + def __init__(self): |
| 557 | + super().__init__() |
| 558 | + self.delegation_spec = DelegationSpec( |
| 559 | + ExecutorBackend.__name__, |
| 560 | + [CompileSpec(key, value) for key, value in self.spec.items()], |
| 561 | + ) |
| 562 | + |
| 563 | + def partition( |
| 564 | + self, edge_exported_program: ExportedProgram |
| 565 | + ) -> PartitionResult: |
| 566 | + partition_tags = {} |
| 567 | + for node in edge_exported_program.graph.nodes: |
| 568 | + if node.op == "call_function" and node.target in [ |
| 569 | + exir_ops.edge.aten.add.Tensor |
| 570 | + ]: |
| 571 | + delegation_tag = "tag0" |
| 572 | + node.meta["delegation_tag"] = delegation_tag |
| 573 | + partition_tags[delegation_tag] = self.delegation_spec |
| 574 | + tag_constant_data(edge_exported_program) |
| 575 | + return PartitionResult( |
| 576 | + tagged_exported_program=edge_exported_program, |
| 577 | + partition_tags=partition_tags, |
| 578 | + ) |
| 579 | + |
| 580 | + # Check the edge program inital buffers_to_mutate |
| 581 | + mutate_op = "aten_add_tensor_1" |
| 582 | + self.assertEqual( |
| 583 | + edge.exported_program().graph_signature.buffers_to_mutate[mutate_op], |
| 584 | + "my_state", |
| 585 | + ) |
| 586 | + edge = edge.to_backend(PartitionerTagData()) |
| 587 | + # After to_backend, add is delegated and is no longer in buffers_to_mutate. |
| 588 | + self.assertNotIn( |
| 589 | + mutate_op, |
| 590 | + edge.exported_program().graph_signature.buffers_to_mutate, |
| 591 | + ) |
| 592 | + |
| 593 | + mutate_op = "getitem_1" |
| 594 | + # Ensure the mutated buffer is not delegated, and the new mutate node is getitem (from call_delegate) |
| 595 | + self.assertEqual( |
| 596 | + edge.exported_program().graph_signature.buffers_to_mutate[mutate_op], |
| 597 | + "my_state", |
| 598 | + ) |
| 599 | + # Check the copy_ node is inserted |
| 600 | + edge = edge.to_executorch() |
| 601 | + copy_node = [ |
| 602 | + node |
| 603 | + for node in edge.exported_program().graph.nodes |
| 604 | + if node.op == "call_function" |
| 605 | + and node.target == torch.ops.aten.copy_.default |
| 606 | + ] |
| 607 | + self.assertEqual(len(copy_node), 1) |
0 commit comments