|
2 | 2 | import torch_tensorrt |
3 | 3 | from parameterized import parameterized |
4 | 4 | from torch.testing._internal.common_utils import TestCase, run_tests |
| 5 | +from parameterized import parameterized |
5 | 6 |
|
6 | 7 | from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing |
7 | 8 |
|
@@ -963,5 +964,60 @@ def forward(self, input): |
963 | 964 | ) |
964 | 965 |
|
965 | 966 |
|
| 967 | +class TestScatterAdd(TestCase): |
| 968 | + @parameterized.expand( |
| 969 | + [ |
| 970 | + ( |
| 971 | + "scatter_add_zero_dim_indexOne_constant", |
| 972 | + 0, |
| 973 | + torch.tensor([[0, 1, 2, 0]]), |
| 974 | + torch.tensor([[1, 2, 3, 4]], dtype=torch.int32), |
| 975 | + ), |
| 976 | + ( |
| 977 | + "scatter_add_zero_dim_indexTwo_constant", |
| 978 | + 0, |
| 979 | + torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]), |
| 980 | + torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32), |
| 981 | + ), |
| 982 | + ( |
| 983 | + "scatter_add_one_dim_indexOne_constant", |
| 984 | + 1, |
| 985 | + torch.tensor([[0, 1, 2, 0]]), |
| 986 | + torch.tensor([[1, 2, 3, 1]], dtype=torch.int32), |
| 987 | + ), |
| 988 | + ( |
| 989 | + "scatter_add_one_dim_indexTwo_costant", |
| 990 | + 1, |
| 991 | + torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]), |
| 992 | + torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32), |
| 993 | + ), |
| 994 | + ] |
| 995 | + ) |
| 996 | + def test_scatter_add(self, _, dim, index, src): |
| 997 | + class TestModule(torch.nn.Module): |
| 998 | + def __init__(self): |
| 999 | + super().__init__() |
| 1000 | + |
| 1001 | + def forward(self, input): |
| 1002 | + return torch.ops.aten.scatter_add.default(input, dim, index, src) |
| 1003 | + |
| 1004 | + # Operations expected to be included in the traced graph after decompositions |
| 1005 | + expected_ops = {torch.ops.aten.scatter.src} |
| 1006 | + |
| 1007 | + input = torch.zeros(3, 5, dtype=torch.int32) |
| 1008 | + inputs = [input] |
| 1009 | + |
| 1010 | + fx_graph = torch.fx.symbolic_trace(TestModule()) |
| 1011 | + _, expected_ops_unseen = lower_graph_testing( |
| 1012 | + fx_graph, inputs, expected_ops=expected_ops, min_block_size=2 |
| 1013 | + ) |
| 1014 | + |
| 1015 | + self.assertEquals( |
| 1016 | + len(expected_ops_unseen), |
| 1017 | + 0, |
| 1018 | + f"The following expected ops were not encountered: {expected_ops_unseen}", |
| 1019 | + ) |
| 1020 | + |
| 1021 | + |
966 | 1022 | if __name__ == "__main__": |
967 | 1023 | run_tests() |
0 commit comments