Skip to content

Commit 10a384e

Browse files
committed
scatter_add_decomposition
Fixing scatter_add test cases. To do: fix the index collision cases Index collision cases Index collision cases- removing the torch.unique checl
1 parent abed8f0 commit 10a384e

File tree

2 files changed

+87
-0
lines changed

2 files changed

+87
-0
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,37 @@ def empty_strided_decomposition(*args, **kwargs) -> torch.Tensor:
243243
)
244244

245245

246+
@register_torch_trt_decomposition(
247+
torch.ops.aten.scatter_add.default, registry=TORCH_TRT_DECOMPOSITIONS
248+
)
249+
def scatter_add_decomposition(
250+
input_tensor: torch.Tensor,
251+
src_tensor: torch.Tensor,
252+
dim: int,
253+
index: torch.Tensor,
254+
) -> torch.Tensor:
255+
scatter_add_tensor = input_tensor
256+
src_copy = src_tensor
257+
src_shape = list(src_tensor.shape)
258+
del src_shape[dim]
259+
select_src_dim = src_copy.shape[dim]
260+
to_stack_dummy_src = tuple(torch.empty(src_shape) for _ in range(select_src_dim))
261+
for index_src_dim in range(0, select_src_dim, 1):
262+
select_tensor_dim = torch.select(src_copy, dim, index_src_dim)
263+
to_stack_src = to_stack_dummy_src
264+
if(index_src_dim == 0):
265+
to_stack_src = (select_tensor_dim.cpu(),) + to_stack_dummy_src[index_src_dim+1:]
266+
elif(index_src_dim == select_src_dim - 1 ):
267+
to_stack_src = to_stack_dummy_src[:index_src_dim] + (select_tensor_dim.cpu(),)
268+
else:
269+
to_stack_src = to_stack_dummy_src[:index_src_dim] + (select_tensor_dim.cpu(),) + to_stack_dummy_src[index_src_dim+1:]
270+
271+
stacked_src = torch.stack(to_stack_src, dim)
272+
input_tensor_to_add = torch.scatter(torch.empty_like(input_tensor, dtype= torch.float32), dim, index, stacked_src.cuda())
273+
scatter_add_tensor = torch.add(scatter_add_tensor, input_tensor_to_add)
274+
return scatter_add_tensor
275+
276+
246277
def get_decompositions(
247278
enable_experimental_decompositions: bool = False,
248279
) -> Dict[OpOverload, Callable[[Any], Any]]:

tests/py/dynamo/lowering/test_decompositions.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch_tensorrt
33
from parameterized import parameterized
44
from torch.testing._internal.common_utils import TestCase, run_tests
5+
from parameterized import parameterized
56

67
from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
78

@@ -963,5 +964,60 @@ def forward(self, input):
963964
)
964965

965966

967+
class TestScatterAdd(TestCase):
968+
@parameterized.expand(
969+
[
970+
(
971+
"scatter_add_zero_dim_indexOne_constant",
972+
0,
973+
torch.tensor([[0, 1, 2, 0]]),
974+
torch.tensor([[1, 2, 3, 4]], dtype=torch.int32),
975+
),
976+
(
977+
"scatter_add_zero_dim_indexTwo_constant",
978+
0,
979+
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
980+
torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32),
981+
),
982+
(
983+
"scatter_add_one_dim_indexOne_constant",
984+
1,
985+
torch.tensor([[0, 1, 2, 0]]),
986+
torch.tensor([[1, 2, 3, 1]], dtype=torch.int32),
987+
),
988+
(
989+
"scatter_add_one_dim_indexTwo_costant",
990+
1,
991+
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
992+
torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32),
993+
),
994+
]
995+
)
996+
def test_scatter_add(self, _, dim, index, src):
997+
class TestModule(torch.nn.Module):
998+
def __init__(self):
999+
super().__init__()
1000+
1001+
def forward(self, input):
1002+
return torch.ops.aten.scatter_add.default(input, dim, index, src)
1003+
1004+
# Operations expected to be included in the traced graph after decompositions
1005+
expected_ops = {torch.ops.aten.scatter.src}
1006+
1007+
input = torch.zeros(3, 5, dtype=torch.int32)
1008+
inputs = [input]
1009+
1010+
fx_graph = torch.fx.symbolic_trace(TestModule())
1011+
_, expected_ops_unseen = lower_graph_testing(
1012+
fx_graph, inputs, expected_ops=expected_ops, min_block_size=2
1013+
)
1014+
1015+
self.assertEquals(
1016+
len(expected_ops_unseen),
1017+
0,
1018+
f"The following expected ops were not encountered: {expected_ops_unseen}",
1019+
)
1020+
1021+
9661022
if __name__ == "__main__":
9671023
run_tests()

0 commit comments

Comments
 (0)