Skip to content

Fix Op (convolution) | add nd support to convolution #2108

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2074,16 +2074,30 @@
) -> TFloat:
"""convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, SymInt[] padding, int[] dilation, bool transposed, SymInt[] output_padding, int groups) -> Tensor"""

rank = len(input.shape)

image_d = rank - 2

# NOTE: We assume the sequence padding/dilation/stride
# from ATen op can only be either len == 1 or
# len == rank.

if not isinstance(padding, Sequence):
padding = (padding, padding)
padding = [padding] * image_d

Check warning on line 2086 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L2086

Added line #L2086 was not covered by tests
elif len(padding) == 1:
padding = [padding[0]] * image_d
pads = [*padding, *padding]

if not isinstance(dilation, Sequence):
dilation = (dilation, dilation)
dilation = [dilation] * image_d

Check warning on line 2092 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L2092

Added line #L2092 was not covered by tests
elif len(dilation) == 1:
dilation = [dilation[0]] * image_d
dilations = list(dilation)

if not isinstance(stride, Sequence):
stride = (stride, stride)
stride = [stride] * image_d

Check warning on line 2098 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L2098

Added line #L2098 was not covered by tests
elif len(stride) == 1:
stride = [stride[0]] * image_d
strides = list(stride)

result = _aten_convolution_onnx(
Expand Down
41 changes: 26 additions & 15 deletions tests/function_libs/torch_lib/extra_opinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,19 @@ def sample_inputs_convolution(op_info, device, dtype, requires_grad, **kwargs):
"groups": 1,
},
),
(
(1, 3, 224, 224),
(32, 3, 3, 3),
None,
{
"stride": (2,),
"padding": (1,),
"dilation": (1,),
"transposed": False,
"output_padding": (0, 0),
"groups": 1,
},
),
(
(1, 3, 3, 224, 224),
(32, 3, 3, 3, 3),
Expand All @@ -252,21 +265,19 @@ def sample_inputs_convolution(op_info, device, dtype, requires_grad, **kwargs):
"groups": 1,
},
),
# FIXME(jiz): Uncomment out these test data once
# torch 2.0 is released.
# (
# (1, 3, 224, 224, 224),
# (32, 3, 3, 3, 3),
# (32,),
# {
# "stride": (2, 2, 2),
# "padding": (1, 1, 1),
# "dilation": (1, 1, 1),
# "transposed": False,
# "output_padding": (0, 0, 0),
# "groups": 1,
# },
# ),
(
(1, 3, 224, 224, 224),
(32, 3, 3, 3, 3),
(32,),
{
"stride": (2, 2, 2),
"padding": (1, 1, 1),
"dilation": (1, 1, 1),
"transposed": False,
"output_padding": (0, 0, 0),
"groups": 1,
},
),
(
(2, 4, 6, 6),
(4, 1, 3, 3),
Expand Down
2 changes: 1 addition & 1 deletion tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1557,7 +1557,7 @@ def _where_input_wrangler(
TorchLibOpInfo(
"ops.aten.convolution",
core_ops.aten_convolution,
tolerance={torch.float32: (3.7e-5, 1.8e-4)},
tolerance={torch.float32: (2e-4, 9e-4)},
),
TorchLibOpInfo("empty_like", core_ops.aten_empty_like, nondeterministic=True),
TorchLibOpInfo(
Expand Down
Loading