diff --git a/core/conversion/converters/impl/shuffle.cpp b/core/conversion/converters/impl/shuffle.cpp index 5a8e992d90..8d25525a6e 100644 --- a/core/conversion/converters/impl/shuffle.cpp +++ b/core/conversion/converters/impl/shuffle.cpp @@ -20,7 +20,12 @@ static auto shuffle_registrations TORCHTRT_UNUSED = auto in_shape = util::toVec(in->getDimensions()); std::vector out_shape; if (ctx->input_is_dynamic) { - end_dim = (end_dim == -1) ? in_shape.size() - 1 : end_dim; + if (start_dim < 0) { + start_dim = start_dim + in_shape.size(); + } + if (end_dim < 0) { + end_dim = end_dim + in_shape.size(); + } int nbDynamicFlattenedDims = 0; int nbDynamicUnflattenedDims = 0; for (int i = 0; i < (int)in_shape.size(); i++) { diff --git a/tests/core/conversion/converters/test_shuffle.cpp b/tests/core/conversion/converters/test_shuffle.cpp index 9c972ba988..a4d22e10c6 100644 --- a/tests/core/conversion/converters/test_shuffle.cpp +++ b/tests/core/conversion/converters/test_shuffle.cpp @@ -4,7 +4,6 @@ #include "tests/util/util.h" #include "torch/csrc/jit/ir/irparser.h" -// TODO: IR Parser doesnt work well with neg numbers TEST(Converters, ATenFlattenConvertsCorrectly) { const auto graph = R"IR( graph(%0 : Tensor): @@ -23,12 +22,32 @@ TEST(Converters, ATenFlattenConvertsCorrectly) { in = at::clone(in); params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); - auto trt = trt_results[0].reshape_as(jit_results[0]); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + +TEST(Converters, ATenFlattenNegDimsConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor): + %1 : int = prim::Constant[value=-3]() + %2 : int = prim::Constant[value=-2]() + %3 : Tensor = aten::flatten(%0, %1, %2) + return (%3))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(0, 5, {2, 3, 3}, {at::kCUDA}); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); + + in = at::clone(in); + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); } -// TODO: IR Parser doesnt work well with neg numbers TEST(Converters, ATenFlattenOtherDimsConvertsCorrectly) { const auto graph = R"IR( graph(%0 : Tensor): @@ -47,9 +66,8 @@ TEST(Converters, ATenFlattenOtherDimsConvertsCorrectly) { in = at::clone(in); params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); - auto trt = trt_results[0].reshape_as(jit_results[0]); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); } TEST(Converters, ATenReshapeConvertsCorrectly) { @@ -215,6 +233,29 @@ TEST(Converters, ATenFlattenConvertsCorrectlyWithDynamicBatch) { ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); } +TEST(Converters, ATenFlattenNegDimsConvertsCorrectlyWithDynamicBatch) { + const auto graph = R"IR( + graph(%0 : Tensor): + %1 : int = prim::Constant[value=-3]() + %2 : int = prim::Constant[value=-2]() + %3 : Tensor = aten::flatten(%0, %1, %2) + return (%3))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(0, 5, {2, 3, 4}, {at::kCUDA}); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); + + in = at::clone(in); + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true); + auto trt = trt_results[0].reshape_as(jit_results[0]); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +} + TEST(Converters, ATenTransposeConvertsCorrectly) { const auto graph = R"IR( graph(%x.1 : Tensor):