Skip to content

Commit 1395110

Browse files
authored
Merge pull request #1283 from mfeliz-cruise/michael.feliz/constant_pad_nd_int
fix: Add int support to constant_pad_nd
2 parents 2fc413b + 911ab5b commit 1395110

File tree

2 files changed

+29
-13
lines changed

2 files changed

+29
-13
lines changed

core/conversion/converters/impl/constant_pad.cpp

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
2121
auto padding = args[1].unwrapToIntList().vec();
2222
int64_t padSize = padding.size();
2323
auto value = args[2].unwrapToScalar().to<float>();
24-
24+
at::Tensor value_tensor = torch::tensor(value, util::TRTDataTypeToScalarType(in->getType()));
25+
auto valueTensor = tensor_to_const(ctx, value_tensor);
2526
TORCHTRT_CHECK(padSize % 2 == 0, "Length of pad must be even but instead it equals " << padSize);
2627

2728
int64_t l_pad = padSize / 2;
@@ -55,10 +56,8 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
5556
auto fill_layer = ctx->net->addFill(nvinfer1::Dims{1, {1}}, nvinfer1::FillOperation::kLINSPACE);
5657
auto shape_gather_out = ctx->net->addShape(*left_gather_out)->getOutput(0);
5758
fill_layer->setInput(0, *shape_gather_out);
58-
at::Tensor value_tensor = torch::tensor(value, torch::kFloat32);
59-
auto valueTensor = tensor_to_const(ctx, value_tensor);
6059
fill_layer->setInput(1, *valueTensor);
61-
at::Tensor delta_tensor = torch::zeros(inRank);
60+
at::Tensor delta_tensor = torch::zeros(inRank, util::TRTDataTypeToScalarType(in->getType()));
6261
auto deltaTensor = tensor_to_const(ctx, delta_tensor);
6362
fill_layer->setInput(2, *deltaTensor);
6463
auto padTensor = fill_layer->getOutput(0);
@@ -69,10 +68,8 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
6968
} else {
7069
inDims.d[axis] = padding[padding_index];
7170
auto fill_layer = ctx->net->addFill(inDims, nvinfer1::FillOperation::kLINSPACE);
72-
at::Tensor value_tensor = torch::tensor(value, torch::kFloat32);
73-
auto valueTensor = tensor_to_const(ctx, value_tensor);
7471
fill_layer->setInput(1, *valueTensor);
75-
at::Tensor delta_tensor = torch::zeros(inRank);
72+
at::Tensor delta_tensor = torch::zeros(inRank, util::TRTDataTypeToScalarType(in->getType()));
7673
auto deltaTensor = tensor_to_const(ctx, delta_tensor);
7774
fill_layer->setInput(2, *deltaTensor);
7875
auto padTensor = fill_layer->getOutput(0);
@@ -112,10 +109,8 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
112109
auto fill_layer = ctx->net->addFill(nvinfer1::Dims{1, {1}}, nvinfer1::FillOperation::kLINSPACE);
113110
auto shape_gather_out = ctx->net->addShape(*right_gather_out)->getOutput(0);
114111
fill_layer->setInput(0, *shape_gather_out);
115-
at::Tensor value_tensor = torch::tensor(value, torch::kFloat32);
116-
auto valueTensor = tensor_to_const(ctx, value_tensor);
117112
fill_layer->setInput(1, *valueTensor);
118-
at::Tensor delta_tensor = torch::zeros(inRank);
113+
at::Tensor delta_tensor = torch::zeros(inRank, util::TRTDataTypeToScalarType(in->getType()));
119114
auto deltaTensor = tensor_to_const(ctx, delta_tensor);
120115
fill_layer->setInput(2, *deltaTensor);
121116
auto padTensor = fill_layer->getOutput(0);
@@ -126,10 +121,8 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
126121
} else {
127122
inDims.d[axis] = padding[padding_index + 1];
128123
auto fill_layer = ctx->net->addFill(inDims, nvinfer1::FillOperation::kLINSPACE);
129-
at::Tensor value_tensor = torch::tensor(value, torch::kFloat32);
130-
auto valueTensor = tensor_to_const(ctx, value_tensor);
131124
fill_layer->setInput(1, *valueTensor);
132-
at::Tensor delta_tensor = torch::zeros(inRank);
125+
at::Tensor delta_tensor = torch::zeros(inRank, util::TRTDataTypeToScalarType(in->getType()));
133126
auto deltaTensor = tensor_to_const(ctx, delta_tensor);
134127
fill_layer->setInput(2, *deltaTensor);
135128
auto padTensor = fill_layer->getOutput(0);

tests/core/conversion/converters/test_constant_pad.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,29 @@ TEST(Converters, ATenConstantPad1dTensorConvertsCorrectly) {
2828
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
2929
}
3030

31+
TEST(Converters, ATenConstantPad1dIntTensorConvertsCorrectly) {
32+
const auto graph = R"IR(
33+
graph(%0 : Tensor):
34+
%1 : int[] = prim::Constant[value=[2, 3]]()
35+
%2 : Scalar = prim::Constant[value=2]()
36+
%3 : Tensor = aten::constant_pad_nd(%0, %1, %2)
37+
return (%3))IR";
38+
39+
auto g = std::make_shared<torch::jit::Graph>();
40+
torch::jit::parseIR(graph, g.get());
41+
42+
auto in1 = at::randint(1, 10, {1, 3, 4}, {at::kCUDA}).toType(at::kInt);
43+
44+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
45+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1});
46+
47+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
48+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1});
49+
50+
ASSERT_TRUE(
51+
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
52+
}
53+
3154
TEST(Converters, ATenConstantPad1dRightZeroTensorConvertsCorrectly) {
3255
const auto graph = R"IR(
3356
graph(%0 : Tensor):

0 commit comments

Comments (0)