diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index cf9836cd3c..2bdea7ca5f 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -2074,16 +2074,30 @@ def aten_convolution( ) -> TFloat: """convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, SymInt[] padding, int[] dilation, bool transposed, SymInt[] output_padding, int groups) -> Tensor""" + rank = len(input.shape) + + image_d = rank - 2 + + # NOTE: We assume the sequence padding/dilation/stride + # from ATen op can only be either len == 1 or + # len == rank. + if not isinstance(padding, Sequence): - padding = (padding, padding) + padding = [padding] * image_d + elif len(padding) == 1: + padding = [padding[0]] * image_d pads = [*padding, *padding] if not isinstance(dilation, Sequence): - dilation = (dilation, dilation) + dilation = [dilation] * image_d + elif len(dilation) == 1: + dilation = [dilation[0]] * image_d dilations = list(dilation) if not isinstance(stride, Sequence): - stride = (stride, stride) + stride = [stride] * image_d + elif len(stride) == 1: + stride = [stride[0]] * image_d strides = list(stride) result = _aten_convolution_onnx( diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py index 2fc79a3dd0..70a1e0547f 100644 --- a/tests/function_libs/torch_lib/extra_opinfo.py +++ b/tests/function_libs/torch_lib/extra_opinfo.py @@ -239,6 +239,19 @@ def sample_inputs_convolution(op_info, device, dtype, requires_grad, **kwargs): "groups": 1, }, ), + ( + (1, 3, 224, 224), + (32, 3, 3, 3), + None, + { + "stride": (2,), + "padding": (1,), + "dilation": (1,), + "transposed": False, + "output_padding": (0, 0), + "groups": 1, + }, + ), ( (1, 3, 3, 224, 224), (32, 3, 3, 3, 3), @@ -252,21 +265,19 @@ def sample_inputs_convolution(op_info, device, dtype, requires_grad, **kwargs): "groups": 1, }, ), - # FIXME(jiz): Uncomment out these test data once - # torch 2.0 is released. - # ( - # (1, 3, 224, 224, 224), - # (32, 3, 3, 3, 3), - # (32,), - # { - # "stride": (2, 2, 2), - # "padding": (1, 1, 1), - # "dilation": (1, 1, 1), - # "transposed": False, - # "output_padding": (0, 0, 0), - # "groups": 1, - # }, - # ), + ( + (1, 3, 224, 224, 224), + (32, 3, 3, 3, 3), + (32,), + { + "stride": (2, 2, 2), + "padding": (1, 1, 1), + "dilation": (1, 1, 1), + "transposed": False, + "output_padding": (0, 0, 0), + "groups": 1, + }, + ), ( (2, 4, 6, 6), (4, 1, 3, 3), diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 75da0c0fd0..e3be105839 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1557,7 +1557,7 @@ def _where_input_wrangler( TorchLibOpInfo( "ops.aten.convolution", core_ops.aten_convolution, - tolerance={torch.float32: (3.7e-5, 1.8e-4)}, + tolerance={torch.float32: (2e-4, 9e-4)}, ), TorchLibOpInfo("empty_like", core_ops.aten_empty_like, nondeterministic=True), TorchLibOpInfo(