diff --git a/core/conversion/converters/impl/layer_norm.cpp b/core/conversion/converters/impl/layer_norm.cpp
index 0c00ee2c4d..5bc4f1a07e 100644
--- a/core/conversion/converters/impl/layer_norm.cpp
+++ b/core/conversion/converters/impl/layer_norm.cpp
@@ -10,138 +10,84 @@ namespace converters {
 namespace impl {
 namespace {
 
+nvinfer1::ITensor* broadcast(
+    ConversionCtx* ctx,
+    const torch::jit::Node* n,
+    nvinfer1::ITensor* to_broadcast,
+    const int nbDims,
+    const std::string& tag) {
+  auto to_broadcast_nbdims = to_broadcast->getDimensions().nbDims;
+  TORCHTRT_CHECK(to_broadcast_nbdims <= nbDims, "Cannot broadcast tensor with more dimensions than the target");
+  if (to_broadcast_nbdims == nbDims) {
+    return to_broadcast;
+  }
+  auto shape_layer = ctx->net->addShape(*to_broadcast);
+  TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
+  shape_layer->setName((util::node_info(n) + "_shape_" + tag).c_str());
+  auto shape_layer_out = shape_layer->getOutput(0);
+
+  auto extra_dims_tensor = torch::ones({nbDims - to_broadcast_nbdims}, torch::TensorOptions().dtype(torch::kInt32));
+  auto extra_dims_itensor = tensor_to_const(ctx, extra_dims_tensor);
+
+  std::vector<nvinfer1::ITensor*> to_concat = {extra_dims_itensor, shape_layer_out};
+  auto concat_layer = ctx->net->addConcatenation(to_concat.data(), to_concat.size());
+  TORCHTRT_CHECK(concat_layer, "Unable to create concat layer from node: " << *n);
+  concat_layer->setName((util::node_info(n) + "_concat_" + tag).c_str());
+  auto target_shape = concat_layer->getOutput(0);
+
+  auto shuffle_layer = ctx->net->addShuffle(*to_broadcast);
+  TORCHTRT_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
+  shuffle_layer->setName((util::node_info(n) + "_shuffle_" + tag).c_str());
+  shuffle_layer->setInput(1, *target_shape);
+  auto output = shuffle_layer->getOutput(0);
+  LOG_DEBUG(
+      "Broadcast " << tag << " to shape: " << output->getDimensions() << " from " << to_broadcast->getDimensions());
+  return output;
+}
+
 auto layer_norm_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern({
     R"SIG(aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? gamma, Tensor? beta, float eps, bool cudnn_enabled) -> (Tensor))SIG",
     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-      auto input = args[0].ITensor(); // assumes non-static input Tensor
-      auto orig_shape = input->getDimensions();
-      auto shape = util::toVec(orig_shape);
-
-      /* Layer_Norm normalizes over last N dimensions.
-         normalizaed_shape could be (C,H,W), (H,W), or (W). */
-      // This could be an IntList or ITensorList. We only need the size of this list.
-      auto normalized_shape = args[1].IValue()->toList();
-
-      // Unwrap eps.
-      auto eps = args[4].unwrapToDouble();
-
-      LOG_DEBUG("cudnn disregarded");
-
-      // Set up axis_ask for E[x].
-      uint32_t axis_mask = 0;
-      for (size_t i = 0; i < normalized_shape.size(); i++) {
-        axis_mask |= 1 << (shape.size() - i - 1);
+      auto input = args[0].ITensorOrFreeze(ctx);
+      auto input_shape = input->getDimensions();
+      auto input_shape_vec = util::toVec(input_shape);
+      auto normalized_shape = args[1].unwrapToIntList();
+      auto normalized_shape_vec = util::toVec(util::toDims(normalized_shape));
+      auto axis = input_shape_vec.size() - normalized_shape_vec.size();
+      uint32_t axes_mask = 0;
+      for (size_t i = axis; i < input_shape_vec.size(); i++) {
+        axes_mask |= 1 << i;
       }
-      LOG_DEBUG("Axis Mask for E[x]" << std::bitset<32>(axis_mask));
-
-      // E[x]
-      auto mean_expected = ctx->net->addReduce(*input, nvinfer1::ReduceOperation::kAVG, axis_mask, true);
-      TORCHTRT_CHECK(mean_expected, "Unable to create mean_expected from node: " << *n);
-      mean_expected->setName((util::node_info(n) + "_mean_expected").c_str());
-      auto mean_expected_out = mean_expected->getOutput(0);
-
-      // X-E[x]
-      auto sub = add_elementwise(
-          ctx, nvinfer1::ElementWiseOperation::kSUB, input, mean_expected_out, (util::node_info(n) + "_sub").c_str());
-      TORCHTRT_CHECK(sub, "Unable to create Sub layer from node: " << *n);
-      sub->setName((util::node_info(n) + "_sub").c_str());
-      auto xsubmean_out = sub->getOutput(0);
-
-      // Variance = mean(pow(xsubmean,2))
-      float pow_scalar = 2;
-      auto exponent = tensor_to_const(ctx, torch::tensor({pow_scalar}));
-      auto pow = add_elementwise(
-          ctx, nvinfer1::ElementWiseOperation::kPOW, xsubmean_out, exponent, (util::node_info(n) + "_pow").c_str());
-      TORCHTRT_CHECK(pow, "Unable to create Pow layer from node: " << *n);
-      pow->setName((util::node_info(n) + "_pow").c_str());
-      auto pow_out = pow->getOutput(0);
-
-      auto mean_var = ctx->net->addReduce(*pow_out, nvinfer1::ReduceOperation::kAVG, axis_mask, true);
-      TORCHTRT_CHECK(mean_var, "Unable to create mean_var from node: " << *n);
-      mean_var->setName((util::node_info(n) + "_mean_var").c_str());
-      auto mean_var_out = mean_var->getOutput(0);
-
-      // Variance + eps
-      auto eps_tensor = tensor_to_const(ctx, torch::tensor({eps}));
-      auto add = add_elementwise(
-          ctx, nvinfer1::ElementWiseOperation::kSUM, mean_var_out, eps_tensor, (util::node_info(n) + "_add").c_str());
-      TORCHTRT_CHECK(add, "Unable to create Add layer from node: " << *n);
-      add->setName((util::node_info(n) + "_add").c_str());
-      auto add_out = add->getOutput(0);
-      // SQRT((Var + eps))
-      auto sqrt = ctx->net->addUnary(*add_out, nvinfer1::UnaryOperation::kSQRT);
-      TORCHTRT_CHECK(sqrt, "Unable to create unary(sqrt) from node: " << *n);
-      sqrt->setName((util::node_info(n) + "_sqrt").c_str());
-      auto sqrt_out = sqrt->getOutput(0);
-
-      // (x - E[x]) / sqrt((var + eps))
-      auto div = add_elementwise(
-          ctx, nvinfer1::ElementWiseOperation::kDIV, xsubmean_out, sqrt_out, (util::node_info(n) + "_div").c_str());
-      TORCHTRT_CHECK(div, "Unable to create div layer from node: " << *n);
-      div->setName((util::node_info(n) + "_div").c_str());
-      auto div_out = div->getOutput(0);
-
-      if (!args[2].IValue()->isTensor() && !args[3].IValue()->isTensor()) {
-        ctx->AssociateValueAndTensor(n->outputs()[0], div_out);
-        return true;
-      }
-
-      // Remove batch dimension from input shape for expand_size, which will
-      // be used to create weights for addScaleNd later.
-      auto expand_size = shape;
-      expand_size.erase(expand_size.begin(), expand_size.begin() + 1);
-
-      // Set up gamma_weights and beta_weights from gamma_expand and
-      // beta_expand.
-      auto gamma_weights = Weights(ctx, at::ones(expand_size));
-      auto beta_weights = Weights(ctx, at::zeros(expand_size));
-
-      if (args[2].IValue()->isTensor()) {
-        torch::Tensor gamma;
-        gamma = args[2].unwrapToTensor();
-        auto gamma_expand = gamma.expand(expand_size);
-        gamma_weights = Weights(ctx, gamma_expand);
+      nvinfer1::ITensor* gamma = nullptr;
+      if (args[2].IValue()->isNone()) {
+        auto gamma_torch_tensor = torch::ones(input_shape_vec, torch::TensorOptions().dtype(torch::kFloat32));
+        gamma = tensor_to_const(ctx, gamma_torch_tensor);
       } else {
-        gamma_weights = Weights(ctx, at::ones(expand_size));
+        gamma = args[2].ITensorOrFreeze(ctx);
+        gamma = broadcast(ctx, n, gamma, input_shape_vec.size(), "gamma");
       }
 
-      if (args[3].IValue()->isTensor()) {
-        torch::Tensor beta;
-        beta = args[3].unwrapToTensor();
-        auto beta_expand = beta.expand(expand_size);
-        beta_weights = Weights(ctx, beta_expand);
+      nvinfer1::ITensor* beta = nullptr;
+      if (args[3].IValue()->isNone()) {
+        auto beta_torch_tensor = torch::zeros(input_shape_vec, torch::TensorOptions().dtype(torch::kFloat32));
+        beta = tensor_to_const(ctx, beta_torch_tensor);
       } else {
-        beta_weights = Weights(ctx, at::zeros(expand_size));
+        beta = args[3].ITensorOrFreeze(ctx);
+        beta = broadcast(ctx, n, beta, input_shape_vec.size(), "beta");
       }
 
-      auto power = Weights(ctx, at::ones(expand_size));
-
-      auto gamma_tensor = ctx->net->addConstant(gamma_weights.shape, gamma_weights.data)->getOutput(0);
-      auto scale_l = add_elementwise(
-          ctx, nvinfer1::ElementWiseOperation::kPROD, div_out, gamma_tensor, (util::node_info(n) + "_scale").c_str());
-
-      auto beta_tensor = ctx->net->addConstant(beta_weights.shape, beta_weights.data)->getOutput(0);
-      auto shift_l = add_elementwise(
-          ctx,
-          nvinfer1::ElementWiseOperation::kSUM,
-          scale_l->getOutput(0),
-          beta_tensor,
-          (util::node_info(n) + "_shift").c_str());
-
-      auto power_tensor = ctx->net->addConstant(power.shape, power.data)->getOutput(0);
-      auto power_l = add_elementwise(
-          ctx,
-          nvinfer1::ElementWiseOperation::kPOW,
-          shift_l->getOutput(0),
-          power_tensor,
-          (util::node_info(n) + "_power").c_str());
+      auto eps = args[4].unwrapToDouble();
 
-      power_l->setName((util::node_info(n) + "_scale_nd").c_str());
-      auto power_l_out = power_l->getOutput(0);
+      auto normalize_layer = ctx->net->addNormalization(*input, *gamma, *beta, axes_mask);
+      TORCHTRT_CHECK(normalize_layer, "Unable to create layer_norm from node: " << *n);
+      normalize_layer->setName(util::node_info(n).c_str());
+      normalize_layer->setEpsilon(eps);
+      normalize_layer->setComputePrecision(nvinfer1::DataType::kFLOAT);
+      auto normalized = normalize_layer->getOutput(0);
 
-      ctx->AssociateValueAndTensor(n->outputs()[0], power_l_out);
+      ctx->AssociateValueAndTensor(n->outputs()[0], normalized);
 
       return true;
     }});
diff --git a/tests/core/conversion/converters/test_layer_norm.cpp b/tests/core/conversion/converters/test_layer_norm.cpp
index 9cd64309cc..9ae04aff1d 100644
--- a/tests/core/conversion/converters/test_layer_norm.cpp
+++ b/tests/core/conversion/converters/test_layer_norm.cpp
@@ -29,8 +29,7 @@ TEST(Converters, ATenLayerNormConvertsCorrectlyLast3DimsNoGammaBeta) {
   params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
   auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
 
-  ASSERT_TRUE(
-      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
 }
 
 TEST(Converters, ATenLayerNormConvertsCorrectlyLast3Dims) {
@@ -60,8 +59,7 @@ TEST(Converters, ATenLayerNormConvertsCorrectlyLast3Dims) {
   params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {gamma, beta});
   auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
 
-  ASSERT_TRUE(
-      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
 }
 
 TEST(Converters, ATenLayerNormConvertsCorrectlyLast2Dims) {
@@ -90,8 +88,7 @@ TEST(Converters, ATenLayerNormConvertsCorrectlyLast2Dims) {
   params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {gamma, beta});
   auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
 
-  ASSERT_TRUE(
-      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
 }
 
 TEST(Converters, ATenLayerNormConvertsCorrectlyLast1Dims) {
@@ -119,8 +116,7 @@ TEST(Converters, ATenLayerNormConvertsCorrectlyLast1Dims) {
   params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {gamma, beta});
   auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
 
-  ASSERT_TRUE(
-      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
 }
 
 TEST(Converters, ATenLayerNormConvertsCorrectly3dInput1dNormalizedShape) {
@@ -134,7 +130,6 @@ TEST(Converters, ATenLayerNormConvertsCorrectly3dInput1dNormalizedShape) {
         %8 : float = prim::Constant[value=1.0000000000000001e-05]()
         %9 : Tensor = aten::layer_norm(%0, %4, %gamma, %beta, %8, %7)
         return (%9))IR";
-
   auto g = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(graph, g.get());
 
@@ -148,6 +143,5 @@ TEST(Converters, ATenLayerNormConvertsCorrectly3dInput1dNormalizedShape) {
   params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {gamma, beta});
   auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
 
-  ASSERT_TRUE(
-      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
 }