Skip to content

Commit 046b6dd

Browse files
committed
Address comments on PR
- Update utility name to add_abs for conciseness - Refactor absolute value utility to return ITensor* - Update logging level for certain debug messages
1 parent 01ee345 commit 046b6dd

File tree

4 files changed

+19
-21
lines changed

4 files changed

+19
-21
lines changed

core/conversion/converters/converter_util.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ nvinfer1::ILayer* add_elementwise(
156156
return ele;
157157
}
158158

159-
nvinfer1::ILayer* add_absolute_value(
159+
nvinfer1::ITensor* add_abs(
160160
ConversionCtx* ctx,
161161
const torch::jit::Node* n,
162162
nvinfer1::ITensor* self,
@@ -185,7 +185,7 @@ nvinfer1::ILayer* add_absolute_value(
185185
TORCHTRT_CHECK(absolute_value_layer, "Unable to create max layer from node: " << *n);
186186
}
187187

188-
return absolute_value_layer;
188+
return absolute_value_layer->getOutput(0);
189189
}
190190

191191
nvinfer1::ITensor* applyIdentityOp(ConversionCtx* ctx, nvinfer1::ITensor* tensor, const std::string& tensor_name) {

core/conversion/converters/converter_util.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,16 @@ nvinfer1::ITensor* addUnpadding(
3535
bool trailing = true,
3636
bool use_zeros = true);
3737

38+
// TODO: Change add_elementwise schema to output nvinfer1::ITensor* instead, for consistency with other utils
39+
// Need to change schema and usage in all calling functions
3840
nvinfer1::ILayer* add_elementwise(
3941
ConversionCtx* ctx,
4042
nvinfer1::ElementWiseOperation op,
4143
nvinfer1::ITensor* self,
4244
nvinfer1::ITensor* other,
4345
const std::string& name);
4446

45-
nvinfer1::ILayer* add_absolute_value(
47+
nvinfer1::ITensor* add_abs(
4648
ConversionCtx* ctx,
4749
const torch::jit::Node* n,
4850
nvinfer1::ITensor* self,

core/conversion/converters/impl/element_wise.cpp

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -326,31 +326,27 @@ auto element_wise_registrations TORCHTRT_UNUSED =
326326
} else if (rounding_mode == "trunc") {
327327
// trunc = floor(abs(div)) * sign(div)
328328
auto tmp_div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, "tmp_div");
329-
auto abs = add_absolute_value(ctx, n, tmp_div->getOutput(0), util::node_info(n) + "_absolute_val");
329+
auto abs = add_abs(ctx, n, tmp_div->getOutput(0), util::node_info(n) + "_absolute_val");
330330

331331
// In this case, we allow the floor unary on non-TRT Unary types, as it is needed for this
332332
// specific function. Floor applied to non-float types equates to identity
333-
nvinfer1::ILayer* floor;
334-
if ((abs->getOutput(0)->getType() == nvinfer1::DataType::kINT32) ||
335-
(abs->getOutput(0)->getType() == nvinfer1::DataType::kBOOL)) {
336-
LOG_GRAPH(
337-
"Tensor is of unsupported type " << abs->getOutput(0)->getType()
333+
nvinfer1::ITensor* floor;
334+
335+
if ((abs->getType() == nvinfer1::DataType::kINT32) || (abs->getType() == nvinfer1::DataType::kBOOL)) {
336+
LOG_DEBUG(
337+
"Tensor is of unsupported type " << abs->getType()
338338
<< " for IUnaryLayer::kFLOOR. Using identity instead.");
339-
floor = ctx->net->addIdentity(*abs->getOutput(0));
340-
TORCHTRT_CHECK(floor, "Unable to create identity layer from node: " << *n);
339+
floor = abs;
341340
} else {
342-
floor = ctx->net->addUnary(*abs->getOutput(0), nvinfer1::UnaryOperation::kFLOOR);
343-
TORCHTRT_CHECK(floor, "Unable to create floor layer from node: " << *n);
341+
auto floor_layer = ctx->net->addUnary(*abs, nvinfer1::UnaryOperation::kFLOOR);
342+
TORCHTRT_CHECK(floor_layer, "Unable to create floor layer from node: " << *n);
343+
floor_layer->setName((util::node_info(n) + "_floor").c_str());
344+
floor = floor_layer->getOutput(0);
344345
}
345-
floor->setName((util::node_info(n) + "_floor").c_str());
346346

347347
auto sign = ctx->net->addUnary(*tmp_div->getOutput(0), nvinfer1::UnaryOperation::kSIGN);
348348
div = add_elementwise(
349-
ctx,
350-
nvinfer1::ElementWiseOperation::kPROD,
351-
floor->getOutput(0),
352-
sign->getOutput(0),
353-
util::node_info(n));
349+
ctx, nvinfer1::ElementWiseOperation::kPROD, floor, sign->getOutput(0), util::node_info(n));
354350
} else {
355351
div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));
356352
}

core/conversion/converters/impl/unary.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ namespace {
1313
auto abs_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
1414
{"aten::abs(Tensor self) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
1515
auto in = args[0].ITensorOrFreeze(ctx);
16-
auto abs_layer = add_absolute_value(ctx, n, in, util::node_info(n));
17-
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], abs_layer->getOutput(0));
16+
auto abs_tensor = add_abs(ctx, n, in, util::node_info(n));
17+
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], abs_tensor);
1818
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
1919
return true;
2020
}});

0 commit comments

Comments
 (0)