5
5
from torch .testing ._internal .common_utils import run_tests
6
6
from torch_tensorrt import Input
7
7
8
- from . harness import DispatchTestCase
8
+ from harness import DispatchTestCase
9
9
10
10
11
11
class TestIndexConverter (DispatchTestCase ):
@@ -26,6 +26,21 @@ def forward(self, x):
26
26
TestModule (),
27
27
input ,
28
28
)
29
+
30
+ def test_index_zero_two_dim_ITensor (self ):
31
+ class TestModule (nn .Module ):
32
+ def forward (self , x , index0 ):
33
+ indices = [None , index0 ]
34
+ out = torch .ops .aten .index .Tensor (x , indices )
35
+ return out
36
+
37
+ input = torch .randn (2 , 2 )
38
+ index0 = torch .randint (0 , 1 , (1 , 1 ))
39
+ index0 = index0 .to (torch .int32 )
40
+ self .run_test (
41
+ TestModule (),
42
+ [input , index0 ],
43
+ )
29
44
30
45
def test_index_zero_index_three_dim (self ):
31
46
class TestModule (nn .Module ):
@@ -43,6 +58,21 @@ def forward(self, x):
43
58
TestModule (),
44
59
input ,
45
60
)
61
+
62
+ def test_index_zero_index_three_dim_ITensor (self ):
63
+ class TestModule (nn .Module ):
64
+ def forward (self , x , index0 ):
65
+ indices = [None , index0 , None ]
66
+ out = torch .ops .aten .index .Tensor (x , indices )
67
+ return out
68
+
69
+ input = torch .randn (2 , 2 , 2 )
70
+ index0 = torch .randint (0 , 1 , (1 , 1 ))
71
+ index0 = index0 .to (torch .int32 )
72
+ self .run_test (
73
+ TestModule (),
74
+ [input , index0 ]
75
+ )
46
76
47
77
def test_index_zero_index_one_index_two_three_dim (self ):
48
78
class TestModule (nn .Module ):
0 commit comments