Add several new ops related to convolution. #431

Merged: 7 commits, Feb 13, 2023
82 changes: 72 additions & 10 deletions onnxscript/function_libs/torch_aten/ops/core.py
@@ -1075,18 +1075,49 @@ def aten_contiguous(self: TensorType, memory_format: str = "contiguous_format")
     raise NotImplementedError()


+@torch_op("aten::conv1d", trace_only=True)
 def aten_conv1d(
-    input: TensorType,
-    weight: TensorType,
-    bias: Optional[TensorType] = None,
+    input: TFloat,
+    weight: TFloat,
+    bias: Optional[TFloat] = None,
     stride: Sequence[int] = (1,),
     padding: Sequence[int] = (0,),
     dilation: Sequence[int] = (1,),
     groups: int = 1,
-) -> TensorType:
+) -> TFloat:
     # conv1d(Tensor input, Tensor weight, Tensor? bias=None, int[1] stride=1, int[1] padding=0, int[1] dilation=1, int groups=1) -> Tensor

-    raise NotImplementedError()
+    # Attributes need to be manipulated in Python to match ONNX's conv1d
+    if not isinstance(padding, Sequence):
+        padding = (padding,)
+    pads = [*padding, *padding]
+
+    if not isinstance(dilation, Sequence):
+        dilation = (dilation,)
+    dilations = list(dilation)
+
+    if not isinstance(stride, Sequence):
+        stride = (stride,)
+    strides = list(stride)
+
+    if bias is None:
+        weight_dim_0 = op.Shape(weight, start=0, end=1)
+        bias_shape = op.Expand(weight_dim_0, op.Constant(value_ints=[1]))
+        zero = op.CastLike(0.0, input)
+        bias = op.Expand(zero, bias_shape)
+
+    result = _aten_convolution_onnx(
+        input,
+        weight,
+        bias,
+        transposed=False,
+        strides=strides,
+        pads=pads,
+        dilations=dilations,
+        groups=groups,
+    )
+
+    return result
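Editor's note on the `pads = [*padding, *padding]` idiom above: ONNX Conv takes `pads` as all begin values followed by all end values per spatial axis ([x1_begin, ..., xn_begin, x1_end, ..., xn_end]), whereas torch's symmetric `padding` gives one value per axis (or a bare int for all axes). A minimal standalone sketch of that normalization, in plain Python; the helper name `to_onnx_conv_attrs` is illustrative and not part of this PR:

    from collections.abc import Sequence

    def to_onnx_conv_attrs(stride, padding, dilation, rank: int) -> dict:
        """Normalize torch-style scalar-or-sequence conv attributes to ONNX Conv form."""
        def as_seq(value):
            # torch accepts a bare int meaning "same value for every spatial axis"
            return tuple(value) if isinstance(value, Sequence) else (value,) * rank

        pad = as_seq(padding)
        return {
            "strides": list(as_seq(stride)),
            "pads": [*pad, *pad],  # symmetric padding: begin values equal end values
            "dilations": list(as_seq(dilation)),
        }

    # conv1d with scalar padding: 2 -> pads=[2, 2]
    assert to_onnx_conv_attrs(1, 2, 1, rank=1) == {"strides": [1], "pads": [2, 2], "dilations": [1]}
    # conv3d with per-axis padding: (1, 1, 1) -> pads=[1, 1, 1, 1, 1, 1]
    assert to_onnx_conv_attrs((2, 2, 2), (1, 1, 1), (1, 1, 1), rank=3)["pads"] == [1, 1, 1, 1, 1, 1]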


 @torch_op("aten::conv2d", trace_only=True)
@@ -1134,18 +1165,49 @@ def aten_conv2d(
     return result


+@torch_op("aten::conv3d", trace_only=True)
 def aten_conv3d(
-    input: TensorType,
-    weight: TensorType,
-    bias: Optional[TensorType] = None,
+    input: TFloat,
+    weight: TFloat,
+    bias: Optional[TFloat] = None,
     stride: Sequence[int] = (1, 1, 1),
     padding: Sequence[int] = (0, 0, 0),
     dilation: Sequence[int] = (1, 1, 1),
     groups: int = 1,
-) -> TensorType:
+) -> TFloat:
     # conv3d(Tensor input, Tensor weight, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1, int groups=1) -> Tensor

-    raise NotImplementedError()
+    # Attributes need to be manipulated in Python to match ONNX's conv3d
+    if not isinstance(padding, Sequence):
+        padding = (padding, padding, padding)
+    pads = [*padding, *padding]
+
+    if not isinstance(dilation, Sequence):
+        dilation = (dilation, dilation, dilation)
+    dilations = list(dilation)
+
+    if not isinstance(stride, Sequence):
+        stride = (stride, stride, stride)
+    strides = list(stride)
+
+    if bias is None:
+        weight_dim_0 = op.Shape(weight, start=0, end=1)
+        bias_shape = op.Expand(weight_dim_0, op.Constant(value_ints=[1]))
+        zero = op.CastLike(0.0, input)
+        bias = op.Expand(zero, bias_shape)
+
+    result = _aten_convolution_onnx(
+        input,
+        weight,
+        bias,
+        transposed=False,
+        strides=strides,
+        pads=pads,
+        dilations=dilations,
+        groups=groups,
+    )
+
+    return result
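Editor's note on the `bias is None` branch used in both new ops: it synthesizes an explicit all-zero bias of length C_out (the weight's first dimension), presumably because the shared `_aten_convolution_onnx` helper expects a bias input. A numpy sketch of what those four ops compute; `make_zero_bias` is illustrative, not part of the PR:

    import numpy as np

    def make_zero_bias(weight: np.ndarray, input_dtype) -> np.ndarray:
        weight_dim_0 = np.asarray(weight.shape[:1])      # op.Shape(weight, start=0, end=1) -> [C_out]
        bias_shape = weight_dim_0                        # op.Expand(weight_dim_0, [1]) is a shape no-op here
        zero = np.zeros((), dtype=input_dtype)           # op.CastLike(0.0, input)
        return np.broadcast_to(zero, tuple(bias_shape))  # op.Expand(zero, bias_shape)

    bias = make_zero_bias(np.ones((32, 3, 3, 3, 3), dtype=np.float32), np.float32)
    assert bias.shape == (32,) and not bias.any()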


 def aten_conv_tbc(
73 changes: 73 additions & 0 deletions onnxscript/tests/function_libs/torch_aten/extra_opinfo.py
@@ -12,6 +12,54 @@
 from torch.testing._internal.opinfo import core as opinfo_core


+def sample_inputs_conv3d(op_info, device, dtype, requires_grad, **kwargs):
+    del op_info
+    make_arg = functools.partial(
+        torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+
+    # Ordered as shapes for input, weight, bias,
+    # and a dict of values of (stride, padding, dilation, groups)
+    cases: tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], dict[str, Any]] = (  # type: ignore[assignment]
+        (
+            (1, 3, 3, 224, 224),
+            (32, 3, 3, 3, 3),
+            None,
+            {
+                "stride": (2, 2, 2),
+                "padding": (1, 1, 1),
+                "dilation": (1, 1, 1),
+                "groups": 1,
+            },
+        ),
+        (
+            (2, 4, 3, 56, 56),
+            (32, 4, 3, 3, 3),
+            (32,),
+            {
+                "stride": (3, 3, 3),
+                "padding": 2,
+                "dilation": (1, 1, 1),
+                "groups": 1,
+            },
+        ),
+    )
+
+    for input_shape, weight, bias, kwargs in cases:  # type: ignore[assignment]
+        # Batched
+        yield opinfo_core.SampleInput(
+            make_arg(input_shape),
+            args=(make_arg(weight), make_arg(bias) if bias is not None else bias),
+            kwargs=kwargs,
+        )
+        # Unbatched
+        yield opinfo_core.SampleInput(
+            make_arg(input_shape[1:]),  # type: ignore[index]
+            args=(make_arg(weight), make_arg(bias) if bias is not None else bias),
+            kwargs=kwargs,
+        )
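Editor's note on how these samples get consumed: each SampleInput packs the first tensor as `.input` plus `.args`/`.kwargs`, so a harness can replay the identical call against eager torch and the ONNX-script function. A hedged sketch of such a replay; the import path follows this PR's tree and running the generator outside the OpInfo machinery is an assumption:

    import torch
    from onnxscript.tests.function_libs.torch_aten import extra_opinfo

    for sample in extra_opinfo.sample_inputs_conv3d(
        None, device="cpu", dtype=torch.float32, requires_grad=False
    ):
        # Replay each sample through eager torch as a shape sanity check;
        # e.g. the first batched case maps (1, 3, 3, 224, 224) -> (1, 32, 2, 112, 112).
        out = torch.nn.functional.conv3d(sample.input, *sample.args, **sample.kwargs)
        print(tuple(out.shape))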


 def sample_inputs_convolution(op_info, device, dtype, requires_grad, **kwargs):
     del op_info
     make_arg = functools.partial(
@@ -60,6 +108,19 @@ def sample_inputs_convolution(op_info, device, dtype, requires_grad, **kwargs):
"groups": 1,
},
),
(
(1, 3, 3, 224, 224),
(32, 3, 3, 3, 3),
(32,),
{
"stride": (2, 2, 2),
"padding": (1, 1, 1),
"dilation": (1, 1, 1),
"transposed": False,
"output_padding": (0, 0, 0),
"groups": 1,
},
),
# FIXME(jiz): Uncomment out these test data once
# torch 2.0 is released.
# (
Expand Down Expand Up @@ -111,4 +172,16 @@ def sample_inputs_convolution(op_info, device, dtype, requires_grad, **kwargs):
         skips=(),
         supports_out=False,
     ),
+    opinfo_core.OpInfo(
+        "nn.functional.conv3d",
+        aliases=("conv3d",),
+        aten_name="conv3d",
+        dtypes=common_dtype.floating_and_complex_types_and(torch.int64, torch.bfloat16),
+        sample_inputs_func=sample_inputs_conv3d,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        gradcheck_nondet_tol=common_utils.GRADCHECK_NONDET_TOL,
+        skips=(),
+        supports_out=False,
+    ),
 ]
@@ -367,7 +367,9 @@ def _where_input_wrangler(
"empty_like": core_ops.aten_empty_like,
"index_select": core_ops.aten_index_select,
"native_layer_norm": core_ops.aten_native_layer_norm,
"nn.functional.conv1d": core_ops.aten_conv1d,
"nn.functional.conv2d": core_ops.aten_conv2d,
"nn.functional.conv3d": core_ops.aten_conv3d,
"nn.functional.gelu": nn_ops.aten_gelu,
"nn.functional.linear": nn_ops.aten_linear,
"ones_like": core_ops.aten_ones_like,
@@ -461,6 +463,11 @@ def _where_input_wrangler(
         matcher=lambda sample: sample.args[0] != (1, 1, 1),
         reason="only global pooling is supported; only batched inputs are supported",
     ),
+    skip(
+        "nn.functional.conv1d",
+        matcher=lambda sample: isinstance(sample.kwargs.get("padding"), str),
+        reason="String padding is not accepted by aten::conv1d",
+    ),
     skip(
         "nn.functional.conv2d",
         matcher=lambda sample: isinstance(sample.kwargs.get("padding"), str),
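Editor's note on the skips above: torch's functional convolutions also accept string padding, which these trace-only functions handle only as ints or int sequences, hence the matcher on `sample.kwargs.get("padding")`. A small illustration of the two call forms; it assumes a torch version new enough to support string padding (1.9+):

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 3, 16)
    w = torch.randn(8, 3, 3)

    F.conv1d(x, w, padding="same")  # eager torch accepts this; the matcher skips such samples
    F.conv1d(x, w, padding=1)       # integer padding: the form aten_conv1d translates to ONNX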