@@ -326,31 +326,27 @@ auto element_wise_registrations TORCHTRT_UNUSED =
326
326
} else if (rounding_mode == " trunc" ) {
327
327
// trunc = floor(abs(div)) * sign(div)
328
328
auto tmp_div = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kDIV , self, other, " tmp_div" );
329
- auto abs = add_absolute_value (ctx, n, tmp_div->getOutput (0 ), util::node_info (n) + " _absolute_val" );
329
+ auto abs = add_abs (ctx, n, tmp_div->getOutput (0 ), util::node_info (n) + " _absolute_val" );
330
330
331
331
// In this case, we allow the floor unary on non-TRT Unary types, as it is needed for this
332
332
// specific function. Floor applied to non-float types equates to identity
333
- nvinfer1::ILayer * floor;
334
- if ((abs-> getOutput ( 0 )-> getType () == nvinfer1::DataType:: kINT32 ) ||
335
- ( abs->getOutput ( 0 ) ->getType () == nvinfer1::DataType::kBOOL )) {
336
- LOG_GRAPH (
337
- " Tensor is of unsupported type " << abs->getOutput ( 0 )-> getType ()
333
+ nvinfer1::ITensor * floor;
334
+
335
+ if (( abs->getType () == nvinfer1::DataType:: kINT32 ) || (abs ->getType () == nvinfer1::DataType::kBOOL )) {
336
+ LOG_DEBUG (
337
+ " Tensor is of unsupported type " << abs->getType ()
338
338
<< " for IUnaryLayer::kFLOOR. Using identity instead." );
339
- floor = ctx->net ->addIdentity (*abs->getOutput (0 ));
340
- TORCHTRT_CHECK (floor, " Unable to create identity layer from node: " << *n);
339
+ floor = abs;
341
340
} else {
342
- floor = ctx->net ->addUnary (*abs->getOutput (0 ), nvinfer1::UnaryOperation::kFLOOR );
343
- TORCHTRT_CHECK (floor, " Unable to create floor layer from node: " << *n);
341
+ auto floor_layer = ctx->net ->addUnary (*abs, nvinfer1::UnaryOperation::kFLOOR );
342
+ TORCHTRT_CHECK (floor_layer, " Unable to create floor layer from node: " << *n);
343
+ floor_layer->setName ((util::node_info (n) + " _floor" ).c_str ());
344
+ floor = floor_layer->getOutput (0 );
344
345
}
345
- floor->setName ((util::node_info (n) + " _floor" ).c_str ());
346
346
347
347
auto sign = ctx->net ->addUnary (*tmp_div->getOutput (0 ), nvinfer1::UnaryOperation::kSIGN );
348
348
div = add_elementwise (
349
- ctx,
350
- nvinfer1::ElementWiseOperation::kPROD ,
351
- floor->getOutput (0 ),
352
- sign->getOutput (0 ),
353
- util::node_info (n));
349
+ ctx, nvinfer1::ElementWiseOperation::kPROD , floor, sign->getOutput (0 ), util::node_info (n));
354
350
} else {
355
351
div = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kDIV , self, other, util::node_info (n));
356
352
}
0 commit comments