diff --git a/core/conversion/converters/impl/matrix_multiply.cpp b/core/conversion/converters/impl/matrix_multiply.cpp
index c4b12da810..90772ea8c4 100644
--- a/core/conversion/converters/impl/matrix_multiply.cpp
+++ b/core/conversion/converters/impl/matrix_multiply.cpp
@@ -16,12 +16,28 @@ auto mm_registrations TORCHTRT_UNUSED =
              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                auto self = args[0].ITensorOrFreeze(ctx);
                auto other = args[1].ITensorOrFreeze(ctx);
+
+               auto selfDims = self->getDimensions().nbDims;
+               auto otherDims = other->getDimensions().nbDims;
+
+               bool squeezeFront = false;
+               bool squeezeBack = false;
+
+               if (selfDims == 1 && selfDims < otherDims) {
+                 squeezeFront = true;
+               } else if (otherDims == 1 && otherDims < selfDims) {
+                 // Append a 1 to the end of the shape before padding front to match self
+                 other = addPadding(ctx, n, other, 2, true, false);
+                 otherDims = other->getDimensions().nbDims;
+                 squeezeBack = true;
+               }
+
                // Ensure self and other tensors have same nbDims by expanding the dimensions (from 0 axis) if
                // necessary.
-               if (self->getDimensions().nbDims < other->getDimensions().nbDims) {
-                 self = addPadding(ctx, n, self, other->getDimensions().nbDims, false, false);
-               } else {
-                 other = addPadding(ctx, n, other, self->getDimensions().nbDims, false, false);
+               if (selfDims < otherDims) {
+                 self = addPadding(ctx, n, self, otherDims, false, false);
+               } else if (otherDims < selfDims) {
+                 other = addPadding(ctx, n, other, selfDims, false, false);
                }
 
                auto mm_layer = ctx->net->addMatrixMultiply(
@@ -29,7 +45,20 @@
                TORCHTRT_CHECK(mm_layer, "Unable to create matrix multiplication node: " << *n);
                mm_layer->setName(util::node_info(n).c_str());
 
-               auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mm_layer->getOutput(0));
+               auto out = mm_layer->getOutput(0);
+
+               if (squeezeFront || squeezeBack) {
+                 auto squeezeDimOffset = squeezeFront ? 2 : 1;
+                 auto reshapeDims =
+                     util::squeezeDims(out->getDimensions(), out->getDimensions().nbDims - squeezeDimOffset);
+                 auto shuffle_layer = ctx->net->addShuffle(*out);
+                 LOG_DEBUG("Squeezing matmul output for 1d correction: " << reshapeDims);
+                 TORCHTRT_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
+                 shuffle_layer->setReshapeDimensions(reshapeDims);
+                 shuffle_layer->setName((util::node_info(n) + "_squeeze").c_str());
+                 out = shuffle_layer->getOutput(0);
+               }
+               auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out);
                LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
 
                return true;
diff --git a/tests/core/conversion/converters/test_matrix_multiply.cpp b/tests/core/conversion/converters/test_matrix_multiply.cpp
index 9c84ba22f6..50248f379a 100644
--- a/tests/core/conversion/converters/test_matrix_multiply.cpp
+++ b/tests/core/conversion/converters/test_matrix_multiply.cpp
@@ -21,9 +21,8 @@ TEST(Converters, ATenMMConvertsCorrectly) {
   params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
   auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2});
 
-  auto trt = trt_results[0].reshape_as(jit_results[0]);
 
-  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
 }
 
 TEST(Converters, ATenMMWithDiffShapesConvertsCorrectly) {
@@ -42,9 +41,131 @@ TEST(Converters, ATenMMWithDiffShapesConvertsCorrectly) {
   params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
   auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2});
 
-  auto trt = trt_results[0].reshape_as(jit_results[0]);
 
-  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenMM1d2dConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor, %1 : Tensor):
+        %2 : Tensor = aten::matmul(%0, %1)
+        return (%2))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in1 = at::randint(0, 5, {10}, {at::kCUDA});
+  auto in2 = at::randint(0, 5, {10, 1}, {at::kCUDA});
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenMM1d3dConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor, %1 : Tensor):
+        %2 : Tensor = aten::matmul(%0, %1)
+        return (%2))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in1 = at::randint(0, 5, {10}, {at::kCUDA});
+  auto in2 = at::randint(0, 5, {2, 10, 8}, {at::kCUDA});
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenMM1d4dConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor, %1 : Tensor):
+        %2 : Tensor = aten::matmul(%0, %1)
+        return (%2))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in1 = at::randint(0, 5, {10}, {at::kCUDA});
+  auto in2 = at::randint(0, 5, {2, 3, 10, 8}, {at::kCUDA});
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenMM3d1dConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor, %1 : Tensor):
+        %2 : Tensor = aten::matmul(%0, %1)
+        return (%2))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in1 = at::randint(0, 5, {2, 10, 8}, {at::kCUDA});
+  auto in2 = at::randint(0, 5, {8}, {at::kCUDA});
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenMM2d1dConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor, %1 : Tensor):
+        %2 : Tensor = aten::matmul(%0, %1)
+        return (%2))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in1 = at::randint(0, 5, {1, 10}, {at::kCUDA});
+  auto in2 = at::randint(0, 5, {10}, {at::kCUDA});
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenMM4d1dConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor, %1 : Tensor):
+        %2 : Tensor = aten::matmul(%0, %1)
+        return (%2))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in1 = at::randint(0, 5, {2, 3, 10, 8}, {at::kCUDA});
+  auto in2 = at::randint(0, 5, {8}, {at::kCUDA});
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
 }
 
 TEST(Converters, ATenBMMConvertsCorrectly) {
@@ -63,9 +184,8 @@ TEST(Converters, ATenBMMConvertsCorrectly) {
   params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
   auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2});
 
-  auto trt = trt_results[0].reshape_as(jit_results[0]);
 
-  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
 }
 
 TEST(Converters, ATenBADDBMMConvertsCorrectly) {