diff --git a/core/conversion/converters/impl/layer_norm.cpp b/core/conversion/converters/impl/layer_norm.cpp index 0c00ee2c4d..f09fc84514 100644 --- a/core/conversion/converters/impl/layer_norm.cpp +++ b/core/conversion/converters/impl/layer_norm.cpp @@ -88,60 +88,25 @@ auto layer_norm_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns() return true; } - // Remove batch dimension from input shape for expand_size, which will - // be used to create weights for addScaleNd later. - auto expand_size = shape; - expand_size.erase(expand_size.begin(), expand_size.begin() + 1); - - // Set up gamma_weights and beta_weights from gamma_expand and - // beta_expand. - auto gamma_weights = Weights(ctx, at::ones(expand_size)); - auto beta_weights = Weights(ctx, at::zeros(expand_size)); + auto normalized = div_out; + // gamma if (args[2].IValue()->isTensor()) { - torch::Tensor gamma; - gamma = args[2].unwrapToTensor(); - auto gamma_expand = gamma.expand(expand_size); - gamma_weights = Weights(ctx, gamma_expand); - } else { - gamma_weights = Weights(ctx, at::ones(expand_size)); + auto gamma = args[2].ITensorOrFreeze(ctx); + auto gamma_prod = add_elementwise( + ctx, nvinfer1::ElementWiseOperation::kPROD, normalized, gamma, (util::node_info(n) + "_gamma").c_str()); + normalized = gamma_prod->getOutput(0); } + // beta if (args[3].IValue()->isTensor()) { - torch::Tensor beta; - beta = args[3].unwrapToTensor(); - auto beta_expand = beta.expand(expand_size); - beta_weights = Weights(ctx, beta_expand); - } else { - beta_weights = Weights(ctx, at::zeros(expand_size)); + auto beta = args[3].ITensorOrFreeze(ctx); + auto beta_sum = add_elementwise( + ctx, nvinfer1::ElementWiseOperation::kSUM, normalized, beta, (util::node_info(n) + "_beta").c_str()); + normalized = beta_sum->getOutput(0); } - auto power = Weights(ctx, at::ones(expand_size)); - - auto gamma_tensor = ctx->net->addConstant(gamma_weights.shape, gamma_weights.data)->getOutput(0); - auto scale_l = add_elementwise( - ctx, nvinfer1::ElementWiseOperation::kPROD, div_out, gamma_tensor, (util::node_info(n) + "_scale").c_str()); - - auto beta_tensor = ctx->net->addConstant(beta_weights.shape, beta_weights.data)->getOutput(0); - auto shift_l = add_elementwise( - ctx, - nvinfer1::ElementWiseOperation::kSUM, - scale_l->getOutput(0), - beta_tensor, - (util::node_info(n) + "_shift").c_str()); - - auto power_tensor = ctx->net->addConstant(power.shape, power.data)->getOutput(0); - auto power_l = add_elementwise( - ctx, - nvinfer1::ElementWiseOperation::kPOW, - shift_l->getOutput(0), - power_tensor, - (util::node_info(n) + "_power").c_str()); - - power_l->setName((util::node_info(n) + "_scale_nd").c_str()); - auto power_l_out = power_l->getOutput(0); - - ctx->AssociateValueAndTensor(n->outputs()[0], power_l_out); + ctx->AssociateValueAndTensor(n->outputs()[0], normalized); return true; }});