chore: Update layer_norm converter to use INormalizationLayer #2509

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
182 changes: 64 additions & 118 deletions core/conversion/converters/impl/layer_norm.cpp
@@ -10,138 +10,84 @@ namespace converters {
namespace impl {
namespace {

nvinfer1::ITensor* broadcast(
ConversionCtx* ctx,
const torch::jit::Node* n,
nvinfer1::ITensor* to_broadcast,
const int nbDims,
const std::string& tag) {
auto to_broadcast_nbdims = to_broadcast->getDimensions().nbDims;
TORCHTRT_CHECK(to_broadcast_nbdims <= nbDims, "Cannot broadcast tensor with more dimensions than the target");
if (to_broadcast_nbdims == nbDims) {
return to_broadcast;
}
auto shape_layer = ctx->net->addShape(*to_broadcast);
TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
shape_layer->setName((util::node_info(n) + "_shape_" + tag).c_str());
auto shape_layer_out = shape_layer->getOutput(0);

auto extra_dims_tensor = torch::ones({nbDims - to_broadcast_nbdims}, torch::TensorOptions().dtype(torch::kInt32));
auto extra_dims_itensor = tensor_to_const(ctx, extra_dims_tensor);

std::vector<nvinfer1::ITensor*> to_concat = {extra_dims_itensor, shape_layer_out};
auto concat_layer = ctx->net->addConcatenation(to_concat.data(), to_concat.size());
TORCHTRT_CHECK(concat_layer, "Unable to create concat layer from node: " << *n);
concat_layer->setName((util::node_info(n) + "_concat_" + tag).c_str());
auto target_shape = concat_layer->getOutput(0);

auto shuffle_layer = ctx->net->addShuffle(*to_broadcast);
TORCHTRT_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
shuffle_layer->setName((util::node_info(n) + "_shuffle_" + tag).c_str());
shuffle_layer->setInput(1, *target_shape);
auto output = shuffle_layer->getOutput(0);
LOG_DEBUG(
"Broadcast " << tag << " to shape: " << output->getDimensions() << " from " << to_broadcast->getDimensions());
return output;
}

auto layer_norm_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern({
R"SIG(aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? gamma, Tensor? beta,
float eps, bool cudnn_enabled) -> (Tensor))SIG",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto input = args[0].ITensor(); // assumes non-static input Tensor
auto orig_shape = input->getDimensions();
auto shape = util::toVec(orig_shape);

/* Layer_Norm normalizes over last N dimensions.
normalized_shape could be (C,H,W), (H,W), or (W). */
// This could be an IntList or ITensorList. We only need the size of this list.
auto normalized_shape = args[1].IValue()->toList();

// Unwrap eps.
auto eps = args[4].unwrapToDouble();

LOG_DEBUG("cudnn disregarded");

// Set up axis_mask for E[x].
uint32_t axis_mask = 0;
for (size_t i = 0; i < normalized_shape.size(); i++) {
axis_mask |= 1 << (shape.size() - i - 1);
auto input = args[0].ITensorOrFreeze(ctx);
auto input_shape = input->getDimensions();
auto input_shape_vec = util::toVec(input_shape);
auto normalized_shape = args[1].unwrapToIntList();
auto normalized_shape_vec = util::toVec(util::toDims(normalized_shape));
auto axis = input_shape_vec.size() - normalized_shape_vec.size();
uint32_t axes_mask = 0;
for (size_t i = axis; i < input_shape_vec.size(); i++) {
axes_mask |= 1 << i;
}
LOG_DEBUG("Axis Mask for E[x]" << std::bitset<32>(axis_mask));

// E[x]
auto mean_expected = ctx->net->addReduce(*input, nvinfer1::ReduceOperation::kAVG, axis_mask, true);
TORCHTRT_CHECK(mean_expected, "Unable to create mean_expected from node: " << *n);
mean_expected->setName((util::node_info(n) + "_mean_expected").c_str());
auto mean_expected_out = mean_expected->getOutput(0);

// X-E[x]
auto sub = add_elementwise(
ctx, nvinfer1::ElementWiseOperation::kSUB, input, mean_expected_out, (util::node_info(n) + "_sub").c_str());
TORCHTRT_CHECK(sub, "Unable to create Sub layer from node: " << *n);
sub->setName((util::node_info(n) + "_sub").c_str());
auto xsubmean_out = sub->getOutput(0);

// Variance = mean(pow(xsubmean,2))
float pow_scalar = 2;
auto exponent = tensor_to_const(ctx, torch::tensor({pow_scalar}));
auto pow = add_elementwise(
ctx, nvinfer1::ElementWiseOperation::kPOW, xsubmean_out, exponent, (util::node_info(n) + "_pow").c_str());
TORCHTRT_CHECK(pow, "Unable to create Pow layer from node: " << *n);
pow->setName((util::node_info(n) + "_pow").c_str());
auto pow_out = pow->getOutput(0);

auto mean_var = ctx->net->addReduce(*pow_out, nvinfer1::ReduceOperation::kAVG, axis_mask, true);
TORCHTRT_CHECK(mean_var, "Unable to create mean_var from node: " << *n);
mean_var->setName((util::node_info(n) + "_mean_var").c_str());
auto mean_var_out = mean_var->getOutput(0);

// Variance + eps
auto eps_tensor = tensor_to_const(ctx, torch::tensor({eps}));
auto add = add_elementwise(
ctx, nvinfer1::ElementWiseOperation::kSUM, mean_var_out, eps_tensor, (util::node_info(n) + "_add").c_str());
TORCHTRT_CHECK(add, "Unable to create Add layer from node: " << *n);
add->setName((util::node_info(n) + "_add").c_str());
auto add_out = add->getOutput(0);

// SQRT((Var + eps))
auto sqrt = ctx->net->addUnary(*add_out, nvinfer1::UnaryOperation::kSQRT);
TORCHTRT_CHECK(sqrt, "Unable to create unary(sqrt) from node: " << *n);
sqrt->setName((util::node_info(n) + "_sqrt").c_str());
auto sqrt_out = sqrt->getOutput(0);

// (x - E[x]) / sqrt((var + eps))
auto div = add_elementwise(
ctx, nvinfer1::ElementWiseOperation::kDIV, xsubmean_out, sqrt_out, (util::node_info(n) + "_div").c_str());
TORCHTRT_CHECK(div, "Unable to create div layer from node: " << *n);
div->setName((util::node_info(n) + "_div").c_str());
auto div_out = div->getOutput(0);

if (!args[2].IValue()->isTensor() && !args[3].IValue()->isTensor()) {
ctx->AssociateValueAndTensor(n->outputs()[0], div_out);
return true;
}

// Remove batch dimension from input shape for expand_size, which will
// be used to create weights for addScaleNd later.
auto expand_size = shape;
expand_size.erase(expand_size.begin(), expand_size.begin() + 1);

// Set up gamma_weights and beta_weights from gamma_expand and
// beta_expand.
auto gamma_weights = Weights(ctx, at::ones(expand_size));
auto beta_weights = Weights(ctx, at::zeros(expand_size));

if (args[2].IValue()->isTensor()) {
torch::Tensor gamma;
gamma = args[2].unwrapToTensor();
auto gamma_expand = gamma.expand(expand_size);
gamma_weights = Weights(ctx, gamma_expand);
nvinfer1::ITensor* gamma = nullptr;
if (args[2].IValue()->isNone()) {
auto gamma_torch_tensor = torch::ones(input_shape_vec, torch::TensorOptions().dtype(torch::kFloat32));
gamma = tensor_to_const(ctx, gamma_torch_tensor);
} else {
gamma_weights = Weights(ctx, at::ones(expand_size));
gamma = args[2].ITensorOrFreeze(ctx);
gamma = broadcast(ctx, n, gamma, input_shape_vec.size(), "gamma");
}

if (args[3].IValue()->isTensor()) {
torch::Tensor beta;
beta = args[3].unwrapToTensor();
auto beta_expand = beta.expand(expand_size);
beta_weights = Weights(ctx, beta_expand);
nvinfer1::ITensor* beta = nullptr;
if (args[3].IValue()->isNone()) {
auto beta_torch_tensor = torch::zeros(input_shape_vec, torch::TensorOptions().dtype(torch::kFloat32));
beta = tensor_to_const(ctx, beta_torch_tensor);
} else {
beta_weights = Weights(ctx, at::zeros(expand_size));
beta = args[3].ITensorOrFreeze(ctx);
beta = broadcast(ctx, n, beta, input_shape_vec.size(), "beta");
}

auto power = Weights(ctx, at::ones(expand_size));

auto gamma_tensor = ctx->net->addConstant(gamma_weights.shape, gamma_weights.data)->getOutput(0);
auto scale_l = add_elementwise(
ctx, nvinfer1::ElementWiseOperation::kPROD, div_out, gamma_tensor, (util::node_info(n) + "_scale").c_str());

auto beta_tensor = ctx->net->addConstant(beta_weights.shape, beta_weights.data)->getOutput(0);
auto shift_l = add_elementwise(
ctx,
nvinfer1::ElementWiseOperation::kSUM,
scale_l->getOutput(0),
beta_tensor,
(util::node_info(n) + "_shift").c_str());

auto power_tensor = ctx->net->addConstant(power.shape, power.data)->getOutput(0);
auto power_l = add_elementwise(
ctx,
nvinfer1::ElementWiseOperation::kPOW,
shift_l->getOutput(0),
power_tensor,
(util::node_info(n) + "_power").c_str());
auto eps = args[4].unwrapToDouble();

power_l->setName((util::node_info(n) + "_scale_nd").c_str());
auto power_l_out = power_l->getOutput(0);
auto normalize_layer = ctx->net->addNormalization(*input, *gamma, *beta, axes_mask);
TORCHTRT_CHECK(normalize_layer, "Unable to create layer_norm from node: " << *n);
normalize_layer->setName(util::node_info(n).c_str());
normalize_layer->setEpsilon(eps);
normalize_layer->setComputePrecision(nvinfer1::DataType::kFLOAT);
auto normalized = normalize_layer->getOutput(0);

ctx->AssociateValueAndTensor(n->outputs()[0], power_l_out);
ctx->AssociateValueAndTensor(n->outputs()[0], normalized);
return true;
}});
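
For reference, the core of this change is replacing the hand-rolled mean/variance/scale pipeline with TensorRT's INormalizationLayer. A minimal sketch of that call pattern, assuming TensorRT 8.6 or newer (the add_layer_norm helper and its arguments are illustrative placeholders, not part of this PR):

#include <NvInfer.h>

// Sketch only: given an already-built network, an input tensor, and gamma/beta
// tensors broadcast to the input rank, emit the same normalization the updated
// converter now produces.
nvinfer1::ITensor* add_layer_norm(
    nvinfer1::INetworkDefinition* network,
    nvinfer1::ITensor* input,
    nvinfer1::ITensor* gamma,
    nvinfer1::ITensor* beta,
    uint32_t axes_mask,
    float eps) {
  auto* norm = network->addNormalization(*input, *gamma, *beta, axes_mask);
  norm->setEpsilon(eps);                                  // maps from aten::layer_norm's eps argument
  norm->setComputePrecision(nvinfer1::DataType::kFLOAT);  // keep the reduction math in FP32
  return norm->getOutput(0);
}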

16 changes: 5 additions & 11 deletions tests/core/conversion/converters/test_layer_norm.cpp
@@ -29,8 +29,7 @@ TEST(Converters, ATenLayerNormConvertsCorrectlyLast3DimsNoGammaBeta) {
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});

ASSERT_TRUE(
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

TEST(Converters, ATenLayerNormConvertsCorrectlyLast3Dims) {
@@ -60,8 +59,7 @@ TEST(Converters, ATenLayerNormConvertsCorrectlyLast3Dims) {
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {gamma, beta});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});

ASSERT_TRUE(
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

TEST(Converters, ATenLayerNormConvertsCorrectlyLast2Dims) {
@@ -90,8 +88,7 @@ TEST(Converters, ATenLayerNormConvertsCorrectlyLast2Dims) {
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {gamma, beta});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});

ASSERT_TRUE(
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

TEST(Converters, ATenLayerNormConvertsCorrectlyLast1Dims) {
@@ -119,8 +116,7 @@ TEST(Converters, ATenLayerNormConvertsCorrectlyLast1Dims) {
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {gamma, beta});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});

ASSERT_TRUE(
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

TEST(Converters, ATenLayerNormConvertsCorrectly3dInput1dNormalizedShape) {
@@ -134,7 +130,6 @@ TEST(Converters, ATenLayerNormConvertsCorrectly3dInput1dNormalizedShape) {
%8 : float = prim::Constant[value=1.0000000000000001e-05]()
%9 : Tensor = aten::layer_norm(%0, %4, %gamma, %beta, %8, %7)
return (%9))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

@@ -148,6 +143,5 @@
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {gamma, beta});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});

ASSERT_TRUE(
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}
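
The axes mask passed to addNormalization marks the trailing dimensions named by normalized_shape; each set bit selects one input dimension to normalize over. A small worked example of the converter's mask loop, illustrative only (shapes chosen here, not taken from the PR):

// Illustrative only: input of rank 4 (N, C, H, W) with normalized_shape = (C, H, W).
// axis = 4 - 3 = 1, so bits 1, 2, and 3 are set.
uint32_t axes_mask = 0;
size_t input_rank = 4;
size_t normalized_rank = 3;
for (size_t i = input_rank - normalized_rank; i < input_rank; i++) {
  axes_mask |= 1u << i;
}
// axes_mask == 0b1110 (14): normalize over C, H, and W, but not the batch dimension.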