@@ -921,6 +921,37 @@ TEST(Converters, ATenIndexTensorNoneIdx0Idx1ConvertsCorrectly) {
921
921
torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
922
922
}
923
923
924
+ TEST (Converters, ATenIndexTensorIdxsNoneConvertsCorrectly) {
925
+ const auto graph = R"IR(
926
+ graph(%x.1 : Tensor,
927
+ %index0 : Tensor,
928
+ %index1 : Tensor,
929
+ %index2 : Tensor):
930
+ %5 : NoneType = prim::Constant()
931
+ %18 : Tensor?[] = prim::ListConstruct(%index0, %index1, %index2, %5)
932
+ %19 : Tensor = aten::index(%x.1, %18)
933
+ return (%19))IR" ;
934
+
935
+ auto g = std::make_shared<torch::jit::Graph>();
936
+ torch::jit::parseIR (graph, g.get ());
937
+
938
+ auto in1 = at::randint (1 , 10 , {4 , 8 , 8 , 4 }, {at::kCUDA });
939
+ auto index0 = at::full ({4 , 13 , 1 }, 1 , {at::kCUDA }).to (torch::kLong );
940
+ auto index1 = at::full ({4 , 13 , 1 }, 2 , {at::kCUDA }).to (torch::kLong );
941
+ auto index2 = at::full ({4 , 13 , 1 }, 3 , {at::kCUDA }).to (torch::kLong );
942
+ auto index0_trt = index0.to (torch::kInt32 );
943
+ auto index1_trt = index1.to (torch::kInt32 );
944
+ auto index2_trt = index2.to (torch::kInt32 );
945
+
946
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
947
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {in1, index0, index1, index2});
948
+
949
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
950
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in1, index0_trt, index1_trt, index2_trt});
951
+
952
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
953
+ }
954
+
924
955
TEST (Converters, ATenUnbindConvertsCorrectly) {
925
956
const auto graph = R"IR(
926
957
graph(%x.1 : Tensor):
0 commit comments