|
21 | 21 | from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer
|
22 | 22 | from executorch.backends.cadence.aot.remove_ops import (
|
23 | 23 | RemoveAliasCopyOpPass,
|
| 24 | + RemoveCatFromSliceCopyPass, |
24 | 25 | RemoveCloneOpPass,
|
25 | 26 | RemoveContiguousOpPass,
|
26 | 27 | RemoveDetachCopyPass,
|
@@ -709,3 +710,54 @@ def forward(self, x):
|
709 | 710 | self.assertEqual(
|
710 | 711 | count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 2
|
711 | 712 | )
|
| 713 | + |
| 714 | + def test_remove_cat_from_slice_copy_all_removal(self) -> None: |
| 715 | + class M(torch.nn.Module): |
| 716 | + def __init__(self): |
| 717 | + super().__init__() |
| 718 | + |
| 719 | + def forward(self, x, y): |
| 720 | + x1 = torch.cat((x, y), 0) # (2, 4) |
| 721 | + return torch.slice_copy(x1, dim=0, start=0, end=1) |
| 722 | + |
| 723 | + inputs = tuple(torch.randn(2, 4) for _ in range(2)) |
| 724 | + graph_module = export_to_edge(M(), inputs).exported_program().graph_module |
| 725 | + p = RemoveCatFromSliceCopyPass() |
| 726 | + graph_module = cast(PassResult, p(graph_module)).graph_module |
| 727 | + |
| 728 | + # Ensure both cat nodes were removed |
| 729 | + self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 0) |
| 730 | + |
| 731 | + def test_remove_cat_from_slice_copy_no_removal(self) -> None: |
| 732 | + class M(torch.nn.Module): |
| 733 | + def __init__(self): |
| 734 | + super().__init__() |
| 735 | + |
| 736 | + def forward(self, x, y): |
| 737 | + x1 = torch.cat((x, y), 0) # (2, 4) |
| 738 | + return torch.slice_copy(x1, dim=0, start=0, end=3) |
| 739 | + |
| 740 | + inputs = tuple(torch.randn(2, 4) for _ in range(2)) |
| 741 | + graph_module = export_to_edge(M(), inputs).exported_program().graph_module |
| 742 | + p = RemoveCatFromSliceCopyPass() |
| 743 | + graph_module = cast(PassResult, p(graph_module)).graph_module |
| 744 | + |
| 745 | + # Ensure both cat nodes were removed |
| 746 | + self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 1) |
| 747 | + |
| 748 | + def test_remove_cat_from_slice_copy_zero_range(self) -> None: |
| 749 | + class M(torch.nn.Module): |
| 750 | + def __init__(self): |
| 751 | + super().__init__() |
| 752 | + |
| 753 | + def forward(self, x, y): |
| 754 | + x1 = torch.cat((x, y), 0) # (2, 4) |
| 755 | + return torch.slice_copy(x1, dim=0, start=0, end=0) |
| 756 | + |
| 757 | + inputs = tuple(torch.randn(2, 4) for _ in range(2)) |
| 758 | + graph_module = export_to_edge(M(), inputs).exported_program().graph_module |
| 759 | + p = RemoveCatFromSliceCopyPass() |
| 760 | + graph_module = cast(PassResult, p(graph_module)).graph_module |
| 761 | + |
| 762 | + # Ensure both cat nodes were removed |
| 763 | + self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 0) |
0 commit comments