Skip to content

Commit d69651c

Browse files
authored
Fix bug: correct the output shape of aten::index.Tensor (#1314)
* support multiple indices for aten::index.Tensor Signed-off-by: Ruoqian Guo <[email protected]> * fix: correct output shape of aten::index.Tensor Signed-off-by: Ruoqian Guo <[email protected]> Signed-off-by: Ruoqian Guo <[email protected]>
1 parent ce67ceb commit d69651c

File tree

2 files changed

+33
-1
lines changed

2 files changed

+33
-1
lines changed

core/conversion/converters/impl/select.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ auto select_registrations TORCHTRT_UNUSED =
267267
.pattern(
268268
{"aten::index.Tensor(Tensor self, Tensor?[] indices) -> (Tensor)",
269269
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
270+
// refer to https://github.com/pytorch/pytorch/blob/master/torch/onnx/symbolic_opset9.py#L4627
270271
auto in = args[0].ITensorOrFreeze(ctx);
271272
auto ts = args[1].IValue()->toListRef();
272273

@@ -471,7 +472,7 @@ auto select_registrations TORCHTRT_UNUSED =
471472
}
472473
}
473474
auto concat_final_shape_layer =
474-
ctx->net->addConcatenation(concat_tensors.data(), concat_tensors.size());
475+
ctx->net->addConcatenation(concat_final_tensors.data(), concat_final_tensors.size());
475476
auto unfold_advanced_shuffle_layer = ctx->net->addShuffle(*shuffle_out);
476477
unfold_advanced_shuffle_layer->setInput(1, *concat_final_shape_layer->getOutput(0));
477478
reshape_output = unfold_advanced_shuffle_layer->getOutput(0);

tests/core/conversion/converters/test_select.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -921,6 +921,37 @@ TEST(Converters, ATenIndexTensorNoneIdx0Idx1ConvertsCorrectly) {
921921
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
922922
}
923923

924+
TEST(Converters, ATenIndexTensorIdxsNoneConvertsCorrectly) {
925+
const auto graph = R"IR(
926+
graph(%x.1 : Tensor,
927+
%index0 : Tensor,
928+
%index1 : Tensor,
929+
%index2 : Tensor):
930+
%5 : NoneType = prim::Constant()
931+
%18 : Tensor?[] = prim::ListConstruct(%index0, %index1, %index2, %5)
932+
%19 : Tensor = aten::index(%x.1, %18)
933+
return (%19))IR";
934+
935+
auto g = std::make_shared<torch::jit::Graph>();
936+
torch::jit::parseIR(graph, g.get());
937+
938+
auto in1 = at::randint(1, 10, {4, 8, 8, 4}, {at::kCUDA});
939+
auto index0 = at::full({4, 13, 1}, 1, {at::kCUDA}).to(torch::kLong);
940+
auto index1 = at::full({4, 13, 1}, 2, {at::kCUDA}).to(torch::kLong);
941+
auto index2 = at::full({4, 13, 1}, 3, {at::kCUDA}).to(torch::kLong);
942+
auto index0_trt = index0.to(torch::kInt32);
943+
auto index1_trt = index1.to(torch::kInt32);
944+
auto index2_trt = index2.to(torch::kInt32);
945+
946+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
947+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0, index1, index2});
948+
949+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
950+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt, index1_trt, index2_trt});
951+
952+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
953+
}
954+
924955
TEST(Converters, ATenUnbindConvertsCorrectly) {
925956
const auto graph = R"IR(
926957
graph(%x.1 : Tensor):

0 commit comments

Comments
 (0)