Commit 151340b

feat: support output_padding argument in deconv converter

1 parent f48f040 · commit 151340b

File tree: 4 files changed, +136 −14 lines


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Lines changed: 1 addition & 9 deletions

@@ -2447,16 +2447,8 @@ def aten_ops_le(
     )
 
 
-def conv_param_validator(
-    conv_node: Node, settings: Optional[CompilationSettings] = None
-) -> bool:
-
-    return conv_node.args[7] in ([0], [0, 0], [0, 0, 0])
-
-
 @dynamo_tensorrt_converter(
     torch.ops.aten.convolution.default,
-    capability_validator=conv_param_validator,
     supports_dynamic_shapes=True,
 )
 @enforce_tensor_types(
@@ -2502,7 +2494,7 @@ def aten_ops_convolution(
         stride=args[3],
         padding=args[4],
         dilation=args[5],
-        # output_padding=args[7],
+        output_padding=args[7],
         groups=args[8],
     )
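With the validator removed, the converter forwards args[7] directly. For orientation, a minimal sketch of the positional layout of torch.ops.aten.convolution.default as consumed above; the tensor shapes are illustrative, not from this commit:

# Positional args of torch.ops.aten.convolution.default, matching the
# indices used by the converter (args[6] = transposed,
# args[7] = output_padding, args[8] = groups).
import torch

x = torch.randn(1, 3, 8)   # (N, C_in, L)
w = torch.randn(3, 4, 3)   # transposed-conv weight: (C_in, C_out // groups, K)
out = torch.ops.aten.convolution.default(
    x, w, None,            # input, weight, bias
    [2],                   # stride         (args[3])
    [1],                   # padding        (args[4])
    [1],                   # dilation       (args[5])
    True,                  # transposed     (args[6])
    [1],                   # output_padding (args[7]) -- now accepted
    1,                     # groups         (args[8])
)
print(out.shape)           # torch.Size([1, 4, 16])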

py/torch_tensorrt/dynamo/conversion/impl/deconv.py
Lines changed: 23 additions & 0 deletions

@@ -6,6 +6,7 @@
 import tensorrt as trt
 import torch
 from torch.fx.node import Target
+
 from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
@@ -105,6 +106,9 @@ def deconvNd(
     padding = (padding,) if isinstance(padding, int) else padding
     stride = (stride,) if isinstance(stride, int) else stride
     dilation = (dilation,) if isinstance(dilation, int) else dilation
+    output_padding = (
+        (output_padding,) if isinstance(output_padding, int) else output_padding
+    )
 
     # Expand parameters manually for Conv1D computations
     if is_deconv1d:
@@ -113,6 +117,11 @@ def deconvNd(
         dilation = (
             extend_attr_to_tuple(dilation, 2) if dilation is not None else dilation
         )
+        output_padding = (
+            (tuple(output_padding) + (0,))
+            if output_padding is not None
+            else output_padding
+        )
 
     set_layer_name(deconv_layer, target, name, source_ir)
 
@@ -126,6 +135,20 @@ def deconvNd(
     if groups is not None:
         deconv_layer.num_groups = groups
 
+    ndims = len(padding)
+    pre_padding_values = []
+    post_padding_values = []
+
+    for dim in range(ndims):
+        pre_padding = padding[dim]
+        post_padding = padding[dim] - output_padding[dim]
+
+        pre_padding_values.append(pre_padding)
+        post_padding_values.append(post_padding)
+
+    deconv_layer.pre_padding = tuple(pre_padding_values)
+    deconv_layer.post_padding = tuple(post_padding_values)
+
     # Handle quantization cases
     if scale is not None and zero_point is not None:
         # Assume the dtype of activation is torch.quint8
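The padding mapping above is the heart of the change: TensorRT's IDeconvolutionLayer has no output_padding input; instead it accepts asymmetric pre_padding/post_padding, and its output size is (in - 1) * stride + dilation * (kernel - 1) + 1 - pre - post. Setting pre = padding and post = padding - output_padding therefore reproduces PyTorch's ConvTranspose formula. A minimal sketch checking the algebra (these helper functions are illustrative, not part of the repo):

def trt_deconv_out_size(in_size, kernel, stride, dilation, pre, post):
    # TensorRT deconvolution output size with asymmetric padding
    return (in_size - 1) * stride + dilation * (kernel - 1) + 1 - pre - post

def torch_deconv_out_size(in_size, kernel, stride, padding, dilation, output_padding):
    # torch.nn.ConvTranspose*d output size formula
    return (in_size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1

for in_size, k, s, p, d, op in [(8, 3, 2, 1, 1, 1), (5, 3, 3, 3, 2, 2)]:
    pre, post = p, p - op  # the mapping used in deconvNd
    assert trt_deconv_out_size(in_size, k, s, d, pre, post) == \
        torch_deconv_out_size(in_size, k, s, p, d, op)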

py/torch_tensorrt/fx/converters/aten_ops_converters.py
Lines changed: 56 additions & 0 deletions

@@ -104,6 +104,62 @@ def aten_ops_batch_norm(
     )
 
 
+@tensorrt_converter(torch.ops.aten.convolution.default)
+def aten_ops_convolution(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+        "weight": args[1],
+        "bias": args[2],
+        "stride": args[3],
+        "padding": args[4],
+        "dilation": args[5],
+        "groups": args[8],
+    }
+    # we do not handle transposed.
+    if args[6] is True:
+        raise RuntimeError(f"Target {target} does not support `transposed=True` ")
+    # we do not handle output_padding.
+    if args[7] not in ([0], [0, 0], [0, 0, 0]):
+        raise RuntimeError(f"Target {target} has non-0 output_padding")
+
+    if len(kwargs_new["stride"]) == 1:
+        return convolution.convNd(
+            network,
+            target,
+            source_ir=SourceIR.ATEN,
+            name=name,
+            is_conv1d=True,
+            input_val=kwargs_new["input"],
+            weight=kwargs_new["weight"],
+            bias=kwargs_new["bias"],
+            stride=kwargs_new["stride"],
+            padding=kwargs_new["padding"],
+            dilation=kwargs_new["dilation"],
+            groups=kwargs_new["groups"],
+        )
+    else:
+        return convolution.convNd(
+            network,
+            target,
+            source_ir=SourceIR.ATEN,
+            name=name,
+            is_conv1d=False,
+            input_val=kwargs_new["input"],
+            weight=kwargs_new["weight"],
+            bias=kwargs_new["bias"],
+            stride=kwargs_new["stride"],
+            padding=kwargs_new["padding"],
+            dilation=kwargs_new["dilation"],
+            groups=kwargs_new["groups"],
+        )
+
+
 @tensorrt_converter(torch.ops.aten.div.default)
 @tensorrt_converter(torch.ops.aten.div.Tensor_mode)
 @tensorrt_converter(torch.ops.aten.div.Tensor)
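Note that this legacy fx path still rejects transposed convolution and non-zero output_padding; only the dynamo converter gains support in this commit. A hypothetical reproduction of the rejected case, assuming torch.export lowers ConvTranspose1d to aten.convolution.default:

import torch

m = torch.nn.ConvTranspose1d(3, 4, kernel_size=3, stride=2, output_padding=1)
ep = torch.export.export(m, (torch.randn(1, 3, 8),))
# The exported graph calls aten.convolution.default with transposed=True
# (args[6]) and output_padding=[1] (args[7]); the fx converter above raises
# RuntimeError for both, while the dynamo converter now handles them.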

tests/py/dynamo/conversion/test_deconvolution_aten.py
Lines changed: 56 additions & 5 deletions

@@ -1,6 +1,7 @@
 import torch
 from parameterized import param, parameterized
 from torch.testing._internal.common_utils import run_tests
+
 from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
@@ -15,6 +16,21 @@ class TestDeconvolutionConverter(DispatchTestCase):
             param("non_zero_padding", 1, padding=1),
             param("dilation", 1, dilation=2),
             param("groups", 1, groups=3),
+            param("output_padding_1", 3, stride=2, padding=1, output_padding=1),
+            param("output_padding_2", 3, stride=2, padding=2, output_padding=1),
+            param("output_padding_3", 3, stride=2, padding=3, output_padding=1),
+            param("output_padding_4", 3, stride=3, padding=2, output_padding=1),
+            param("output_padding_5", 3, stride=3, padding=3, output_padding=1),
+            param("output_padding_6", 3, stride=3, padding=3, output_padding=2),
+            param(
+                "combined_params",
+                3,
+                stride=3,
+                padding=3,
+                dilation=2,
+                groups=3,
+                output_padding=2,
+            ),
         ]
     )
     def test_deconv1d(
@@ -26,6 +42,7 @@ def test_deconv1d(
         dilation=1,
         groups=1,
         bias=True,
+        output_padding=0,
    ):
         class TestModule(torch.nn.Module):
             def __init__(self):
@@ -36,9 +53,10 @@ def __init__(self):
                     kernel_size=kernel_size,
                     stride=stride,
                     padding=padding,
-                    dilation=dilation,
+                    output_padding=output_padding,
                     groups=groups,
                     bias=bias,
+                    dilation=dilation,
                 )
 
             def forward(self, x):
@@ -101,6 +119,22 @@ def forward(self, x):
             param("non_zero_padding", 1, padding=1),
             param("dilation", 1, dilation=2),
             param("groups", 1, groups=3),
+            param("output_padding_1", 3, stride=2, padding=1, output_padding=1),
+            param("output_padding_2", 3, stride=2, padding=1, output_padding=1),
+            param("output_padding_3", 3, stride=2, padding=2, output_padding=1),
+            param("output_padding_4", 3, stride=2, padding=3, output_padding=1),
+            param("output_padding_5", 3, stride=3, padding=2, output_padding=1),
+            param("output_padding_6", 3, stride=3, padding=3, output_padding=1),
+            param("output_padding_7", 3, stride=3, padding=3, output_padding=2),
+            param(
+                "combined_params",
+                3,
+                stride=3,
+                padding=3,
+                dilation=2,
+                groups=3,
+                output_padding=2,
+            ),
         ]
     )
     def test_deconv2d(
@@ -112,6 +146,7 @@ def test_deconv2d(
         dilation=1,
         groups=1,
         bias=True,
+        output_padding=0,
     ):
         class TestModule(torch.nn.Module):
             def __init__(self):
@@ -122,9 +157,10 @@ def __init__(self):
                     kernel_size=kernel_size,
                     stride=stride,
                     padding=padding,
-                    dilation=dilation,
+                    output_padding=output_padding,
                     groups=groups,
                     bias=bias,
+                    dilation=dilation,
                 )
 
             def forward(self, x):
@@ -172,6 +208,19 @@ def forward(self, x):
             param("non_zero_padding", 1, padding=1),
             param("dilation", 1, dilation=2),
             param("groups", 1, groups=3),
+            param("output_padding_1", 3, stride=2, padding=1, output_padding=1),
+            param("output_padding_2", 3, stride=2, padding=2, output_padding=1),
+            param("output_padding_3", 3, stride=3, padding=3, output_padding=1),
+            param("output_padding_4", 3, stride=3, padding=3, output_padding=2),
+            param(
+                "combined_params",
+                3,
+                stride=3,
+                padding=3,
+                dilation=2,
+                groups=3,
+                output_padding=2,
+            ),
         ]
     )
     def test_deconv3d(
@@ -183,6 +232,7 @@ def test_deconv3d(
         dilation=1,
         groups=1,
         bias=True,
+        output_padding=0,
     ):
         class TestModule(torch.nn.Module):
             def __init__(self):
@@ -193,9 +243,10 @@ def __init__(self):
                     kernel_size=kernel_size,
                     stride=stride,
                     padding=padding,
-                    dilation=dilation,
+                    output_padding=output_padding,
                     groups=groups,
                     bias=bias,
+                    dilation=dilation,
                 )
 
             def forward(self, x):
@@ -209,8 +260,8 @@ def forward(self, x):
             enable_passes=True,
         )
 
-    # Testing with (-1, -1, -1, -1, -1) results into Error:
-    # AssertionError: Channel dim can't be dynamic for deconvolution.
+    # # Testing with (-1, -1, -1, -1, -1) results into Error:
+    # # AssertionError: Channel dim can't be dynamic for deconvolution.
 
     def test_deconv3d_with_dynamic_shape(self):
         class TestModule(torch.nn.Module):
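A quick sanity check of one new case ("output_padding_1": kernel_size 3, stride 2, padding 1, output_padding 1), with illustrative channel counts; note every new case keeps output_padding below stride, as PyTorch requires:

import torch

deconv = torch.nn.ConvTranspose2d(
    3, 3, kernel_size=3, stride=2, padding=1, output_padding=1
)
y = deconv(torch.randn(1, 3, 8, 8))
# (8 - 1) * 2 - 2 * 1 + 1 * (3 - 1) + 1 + 1 = 16 per spatial dim
print(y.shape)  # torch.Size([1, 3, 16, 16])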
