From eccca5f467d13dc2e1d7aa30fb1d3f688602ec24 Mon Sep 17 00:00:00 2001 From: Michael Feliz Date: Tue, 5 Sep 2023 14:48:04 -0700 Subject: [PATCH 1/2] Rewrite layer norm converter for perf --- .../conversion/converters/impl/layer_norm.cpp | 65 +++++-------------- 1 file changed, 15 insertions(+), 50 deletions(-) diff --git a/core/conversion/converters/impl/layer_norm.cpp b/core/conversion/converters/impl/layer_norm.cpp index 0c00ee2c4d..c9467777f4 100644 --- a/core/conversion/converters/impl/layer_norm.cpp +++ b/core/conversion/converters/impl/layer_norm.cpp @@ -88,60 +88,25 @@ auto layer_norm_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns() return true; } - // Remove batch dimension from input shape for expand_size, which will - // be used to create weights for addScaleNd later. - auto expand_size = shape; - expand_size.erase(expand_size.begin(), expand_size.begin() + 1); - - // Set up gamma_weights and beta_weights from gamma_expand and - // beta_expand. - auto gamma_weights = Weights(ctx, at::ones(expand_size)); - auto beta_weights = Weights(ctx, at::zeros(expand_size)); - - if (args[2].IValue()->isTensor()) { - torch::Tensor gamma; - gamma = args[2].unwrapToTensor(); - auto gamma_expand = gamma.expand(expand_size); - gamma_weights = Weights(ctx, gamma_expand); - } else { - gamma_weights = Weights(ctx, at::ones(expand_size)); + auto normalized = div_out; + + //gamma + if (args[2].IValue()->isTensor()){ + auto gamma = args[2].ITensorOrFreeze(ctx); + auto gamma_prod = add_elementwise( + ctx, nvinfer1::ElementWiseOperation::kPROD, normalized, gamma, (util::node_info(n) + "_gamma").c_str()); + normalized = gamma_prod->getOutput(0); } - if (args[3].IValue()->isTensor()) { - torch::Tensor beta; - beta = args[3].unwrapToTensor(); - auto beta_expand = beta.expand(expand_size); - beta_weights = Weights(ctx, beta_expand); - } else { - beta_weights = Weights(ctx, at::zeros(expand_size)); + //beta + if (args[3].IValue()->isTensor()){ + auto beta = args[3].ITensorOrFreeze(ctx); + auto beta_sum = add_elementwise( + ctx, nvinfer1::ElementWiseOperation::kSUM, normalized, beta, (util::node_info(n) + "_beta").c_str()); + normalized = beta_sum->getOutput(0); } - auto power = Weights(ctx, at::ones(expand_size)); - - auto gamma_tensor = ctx->net->addConstant(gamma_weights.shape, gamma_weights.data)->getOutput(0); - auto scale_l = add_elementwise( - ctx, nvinfer1::ElementWiseOperation::kPROD, div_out, gamma_tensor, (util::node_info(n) + "_scale").c_str()); - - auto beta_tensor = ctx->net->addConstant(beta_weights.shape, beta_weights.data)->getOutput(0); - auto shift_l = add_elementwise( - ctx, - nvinfer1::ElementWiseOperation::kSUM, - scale_l->getOutput(0), - beta_tensor, - (util::node_info(n) + "_shift").c_str()); - - auto power_tensor = ctx->net->addConstant(power.shape, power.data)->getOutput(0); - auto power_l = add_elementwise( - ctx, - nvinfer1::ElementWiseOperation::kPOW, - shift_l->getOutput(0), - power_tensor, - (util::node_info(n) + "_power").c_str()); - - power_l->setName((util::node_info(n) + "_scale_nd").c_str()); - auto power_l_out = power_l->getOutput(0); - - ctx->AssociateValueAndTensor(n->outputs()[0], power_l_out); + ctx->AssociateValueAndTensor(n->outputs()[0], normalized); return true; }}); From 5bf77112f7331d806c5d326a7994e9af557c4f78 Mon Sep 17 00:00:00 2001 From: Michael Feliz Date: Tue, 5 Sep 2023 14:54:01 -0700 Subject: [PATCH 2/2] Improve performance of converted layer_norm ops --- core/conversion/converters/impl/layer_norm.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/core/conversion/converters/impl/layer_norm.cpp b/core/conversion/converters/impl/layer_norm.cpp index c9467777f4..f09fc84514 100644 --- a/core/conversion/converters/impl/layer_norm.cpp +++ b/core/conversion/converters/impl/layer_norm.cpp @@ -90,16 +90,16 @@ auto layer_norm_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns() auto normalized = div_out; - //gamma - if (args[2].IValue()->isTensor()){ + // gamma + if (args[2].IValue()->isTensor()) { auto gamma = args[2].ITensorOrFreeze(ctx); auto gamma_prod = add_elementwise( ctx, nvinfer1::ElementWiseOperation::kPROD, normalized, gamma, (util::node_info(n) + "_gamma").c_str()); normalized = gamma_prod->getOutput(0); } - //beta - if (args[3].IValue()->isTensor()){ + // beta + if (args[3].IValue()->isTensor()) { auto beta = args[3].ITensorOrFreeze(ctx); auto beta_sum = add_elementwise( ctx, nvinfer1::ElementWiseOperation::kSUM, normalized, beta, (util::node_info(n) + "_beta").c_str());