Skip to content

Commit fb91a75

Browse files
committed
Index ITensor test
1 parent 563ca81 commit fb91a75

File tree

1 file changed

+31
-1
lines changed

1 file changed

+31
-1
lines changed

tests/py/dynamo/conversion/test_index_aten.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from torch.testing._internal.common_utils import run_tests
66
from torch_tensorrt import Input
77

8-
from .harness import DispatchTestCase
8+
from harness import DispatchTestCase
99

1010

1111
class TestIndexConverter(DispatchTestCase):
@@ -26,6 +26,21 @@ def forward(self, x):
2626
TestModule(),
2727
input,
2828
)
29+
30+
def test_index_zero_two_dim_ITensor(self):
31+
class TestModule(nn.Module):
32+
def forward(self, x, index0):
33+
indices = [None, index0]
34+
out = torch.ops.aten.index.Tensor(x, indices)
35+
return out
36+
37+
input = torch.randn(2, 2)
38+
index0 = torch.randint(0, 1, (1, 1))
39+
index0 = index0.to(torch.int32)
40+
self.run_test(
41+
TestModule(),
42+
[input, index0],
43+
)
2944

3045
def test_index_zero_index_three_dim(self):
3146
class TestModule(nn.Module):
@@ -43,6 +58,21 @@ def forward(self, x):
4358
TestModule(),
4459
input,
4560
)
61+
62+
def test_index_zero_index_three_dim_ITensor(self):
63+
class TestModule(nn.Module):
64+
def forward(self, x, index0):
65+
indices = [None, index0, None]
66+
out = torch.ops.aten.index.Tensor(x, indices)
67+
return out
68+
69+
input = torch.randn(2, 2, 2)
70+
index0 = torch.randint(0, 1, (1, 1))
71+
index0 = index0.to(torch.int32)
72+
self.run_test(
73+
TestModule(),
74+
[input, index0]
75+
)
4676

4777
def test_index_zero_index_one_index_two_three_dim(self):
4878
class TestModule(nn.Module):

0 commit comments

Comments
 (0)