feat(atenlib): ops 7/n #279

Merged 14 commits on Jan 7, 2023
Changes from 11 commits
35 changes: 11 additions & 24 deletions onnxscript/function_libs/torch_aten/ops/core.py
@@ -48,20 +48,6 @@ def aten_acosh(self: TFloat) -> TFloat:
return op.Acosh(self)


def aten_adaptive_avg_pool1d(self: TensorType, output_size: Sequence[int]) -> TensorType:
# adaptive_avg_pool1d(Tensor self, int[1] output_size) -> Tensor

raise NotImplementedError()


def aten_adaptive_max_pool1d(
self: TensorType, output_size: Sequence[int]
) -> tuple[TensorType, TensorType]:
# adaptive_max_pool1d(Tensor self, int[1] output_size) -> (Tensor, Tensor)

raise NotImplementedError()


@torch_op("aten::add")
def aten_add(self: TReal, other: TReal, alpha: float = 1) -> TReal:
# add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
@@ -198,20 +184,20 @@ def aten_alpha_dropout(input: TensorType, p: float, train: bool) -> TensorType:
raise NotImplementedError()


def aten_amax(
self: TensorType, dim: Optional[Sequence[int]] = None, keepdim: bool = False
) -> TensorType:
# @torch_op("aten::amax") # FIXME: Uncomment when CI uses onnx 1.13
def aten_amax(self: TReal, dim: INT64, keepdim: int = 0) -> TReal:
# amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor

raise NotImplementedError()
# TODO(justinchuby): Make dim optional, keepdim bool
return op.ReduceMax(self, dim, keepdims=keepdim)


def aten_amin(
self: TensorType, dim: Optional[Sequence[int]] = None, keepdim: bool = False
) -> TensorType:
# @torch_op("aten::amin") # FIXME: Uncomment when CI uses onnx 1.13
def aten_amin(self: TReal, dim: INT64, keepdim: int = 0) -> TReal:
# amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor

raise NotImplementedError()
# TODO(justinchuby): Make dim optional, keepdim bool
return op.ReduceMin(self, dim, keepdims=keepdim)

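For intuition, a minimal numpy sketch (illustrative only, not part of the diff) of the reduction these two mappings compute: torch.amax/amin with keepdim corresponds to ReduceMax/ReduceMin with keepdims.

import numpy as np

x = np.array([[1.0, 5.0], [3.0, 2.0]], dtype=np.float32)
# ReduceMax(x, axes=[1], keepdims=1) and torch.amax(x, dim=1, keepdim=True) both keep the reduced axis.
assert np.array_equal(np.amax(x, axis=1, keepdims=True), [[5.0], [3.0]])
assert np.array_equal(np.amin(x, axis=1, keepdims=True), [[1.0], [2.0]])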

def aten_aminmax(
@@ -4181,10 +4167,11 @@ def aten_rsqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
return op.Reciprocal(op.Sqrt(self))


def aten_rsub(self: TensorType, other: TensorType, alpha: float = 1) -> TensorType:
@torch_op("aten::rsub")
def aten_rsub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
# rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor

raise NotImplementedError()
return op.Sub(other, op.Mul(self, alpha))

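A quick sanity check of the composition above (illustrative only): rsub computes other - alpha * self, so swapping the operands of Sub and scaling self by alpha reproduces the PyTorch semantics.

import numpy as np

self_, other, alpha = np.float32(2.0), np.float32(10.0), 3.0
# other - alpha * self, i.e. op.Sub(other, op.Mul(self, alpha))
assert other - alpha * self_ == np.float32(4.0)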

def aten_scalar_tensor(s: float) -> TensorType:
51 changes: 44 additions & 7 deletions onnxscript/function_libs/torch_aten/ops/nn.py
@@ -23,15 +23,38 @@
from onnxscript.onnx_types import TensorType


def aten_adaptive_avg_pool2d(self: TensorType, output_size: INT64) -> TensorType:
@torch_op("aten::aten_adaptive_avg_pool1d")
def aten_adaptive_avg_pool1d(self: TFloat, output_size: INT64[1]) -> TFloat:
# adaptive_avg_pool1d(Tensor self, int[1] output_size) -> Tensor

# assert output_size == [1]
# TODO(justinchuby): Specify input constraints
return op.GlobalAveragePool(self)


@torch_op("aten::aten_adaptive_avg_pool2d")
def aten_adaptive_avg_pool2d(self: TFloat, output_size: INT64[2]) -> TFloat:
# adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor

raise NotImplementedError()
# assert output_size == [1, 1]
# TODO(justinchuby): Specify input constraints
return op.GlobalAveragePool(self)


def aten_adaptive_avg_pool3d(self: TensorType, output_size: INT64) -> TensorType:
@torch_op("aten::aten_adaptive_avg_pool3d")
def aten_adaptive_avg_pool3d(self: TFloat, output_size: INT64[3]) -> TFloat:
# adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor

# assert output_size == [1, 1, 1]
# TODO(justinchuby): Specify input constraints
return op.GlobalAveragePool(self)


def aten_adaptive_max_pool1d(
self: TensorType, output_size: Sequence[int]
) -> tuple[TensorType, TensorType]:
# adaptive_max_pool1d(Tensor self, int[1] output_size) -> (Tensor, Tensor)

raise NotImplementedError()


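The three adaptive_avg_pool functions above lean on the fact that adaptive average pooling to an all-ones output size is exactly a global average over the spatial dimensions, which is why the output_size checks are left as TODOs and the subtest matchers below only keep samples whose output_size is all ones. A minimal numpy sketch of that equivalence (illustrative only, not part of the diff):

import numpy as np

x = np.random.rand(2, 3, 4, 5).astype(np.float32)  # hypothetical NCHW input
# GlobalAveragePool(x) == adaptive_avg_pool2d(x, output_size=(1, 1)): average over H and W, keep dims.
global_avg = x.mean(axis=(2, 3), keepdims=True)
assert global_avg.shape == (2, 3, 1, 1)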
@@ -1162,15 +1185,29 @@ def aten_upsample_nearest1d_backward(
raise NotImplementedError()


@torch_op("aten::upsample_nearest2d")
def aten_upsample_nearest2d(
self: TensorType,
output_size: INT64,
self: TReal,
size: INT64,
scales_h: Optional[float] = None,
scales_w: Optional[float] = None,
) -> TensorType:
) -> TReal:
# upsample_nearest2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor

raise NotImplementedError()
self_shape = op.Shape(self)
batch_channel = self_shape[:2] # type: ignore[index]
output_size = op.Concat(batch_channel, size, axis=0)

# TODO(justinchuby): Conditionally use scales

return op.Resize(
self,
None,
None,
output_size,
mode="nearest",
coordinate_transformation_mode="asymmetric",
)

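For intuition (illustrative only): with mode="nearest" and coordinate_transformation_mode="asymmetric", resizing to an integer multiple of the input size reduces to repeating elements along H and W, while the Concat above simply prepends the untouched batch and channel dimensions to the requested spatial size. A small numpy sketch, assuming a 2x upscale:

import numpy as np

x = np.arange(4, dtype=np.float32).reshape(1, 1, 2, 2)  # N, C, H, W
# output_size = concat([N, C], size) = (1, 1, 4, 4); nearest/asymmetric then repeats each pixel.
up = np.repeat(np.repeat(x, 2, axis=2), 2, axis=3)
assert up.shape == (1, 1, 4, 4)
assert up[0, 0, 0, 0] == up[0, 0, 0, 1] == x[0, 0, 0, 0]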

def aten_upsample_nearest2d_backward(
128 changes: 109 additions & 19 deletions onnxscript/test/function_libs/torch_aten/ops_correctness_test.py
@@ -4,11 +4,11 @@
import copy
import dataclasses
import unittest
import warnings
from typing import Any, Callable, Collection, Iterable, Optional, Sequence, TypeVar

import numpy as np
import onnx
import onnxruntime.capi.onnxruntime_pybind11_state
import torch
from torch.testing._internal import common_device_type, common_methods_invocations
from torch.testing._internal.opinfo import core as opinfo_core
@@ -143,16 +143,52 @@ def wrapped(fn):
return wrapped


def duplicate_opinfo(opinfos: list[opinfo_core.OpInfo], name: str, new_names: tuple[str, ...]):
"""Duplicate an opinfo in the opinfo database and give it a new name."""
duplicated = []
for opinfo in opinfos:
if opinfo.name == name:
for new_name in new_names:
new_opinfo = copy.deepcopy(opinfo)
new_opinfo.name = new_name
duplicated.append(new_opinfo)
opinfos.extend(duplicated)

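A minimal sketch of what the duplicate_opinfo helper above does, using a hypothetical stand-in for OpInfo (the real call, on the upsample_nearest OpInfo, appears near the end of this file's diff):

import dataclasses

@dataclasses.dataclass
class _FakeOpInfo:  # hypothetical stand-in, only for illustration
    name: str

_db = [_FakeOpInfo("nn.functional.upsample_nearest")]
duplicate_opinfo(_db, "nn.functional.upsample_nearest", ("nn.functional.upsample_nearest2d",))
assert [o.name for o in _db] == [
    "nn.functional.upsample_nearest",
    "nn.functional.upsample_nearest2d",
]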

# Create a copy of the op_db to modify
OPS_DB = copy.deepcopy(common_methods_invocations.op_db)

# Modify this section ##########################################################


def _amax_amin_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
if "dim" not in kwargs:
kwargs["dim"] = None
return kwargs


def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
if "scale_factor" in kwargs:
kwargs["scales_h"] = kwargs["scale_factor"]
kwargs["scales_w"] = kwargs["scale_factor"]
del kwargs["scale_factor"]
if "size" in kwargs:
kwargs["size"] = np.array(kwargs["size"])
return kwargs

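To make the wranglers' role concrete (illustrative only): each one rewrites the kwargs of an OpInfo sample into the aten-style arguments the onnxscript functions expect.

# Given the wranglers defined above:
assert _amax_amin_kwargs_wrangler({}) == {"dim": None}
assert _upsample_kwargs_wrangler({"scale_factor": 2.0}) == {"scales_h": 2.0, "scales_w": 2.0}
# A "size" kwarg is converted to np.array so it is passed to Resize as a tensor input.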

# Ops to be tested for numerical consistency between onnx and pytorch
# Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py
OPINFO_FUNCTION_MAPPING: dict[str, onnxscript.OnnxFunction] = {
OPINFO_FUNCTION_MAPPING: dict[
str, onnxscript.OnnxFunction | tuple[onnxscript.OnnxFunction, Callable]
] = {
"abs": core_ops.aten_abs,
"acos": core_ops.aten_acos,
"acosh": core_ops.aten_acosh,
"add": core_ops.aten_add,
"addmm": core_ops.aten_addmm,
"amax": (core_ops.aten_amax, _amax_amin_kwargs_wrangler),
"amin": (core_ops.aten_amin, _amax_amin_kwargs_wrangler),
"asin": core_ops.aten_asin,
"asinh": core_ops.aten_asinh,
"atan": core_ops.aten_atan,
@@ -180,12 +216,19 @@ def wrapped(fn):
"ne": core_ops.aten_ne,
"neg": core_ops.aten_neg,
"new_full": core_ops.aten_new_full,
"nn.functional.adaptive_avg_pool1d": nn_ops.aten_adaptive_avg_pool1d,
"nn.functional.adaptive_avg_pool2d": nn_ops.aten_adaptive_avg_pool2d,
"nn.functional.adaptive_avg_pool3d": nn_ops.aten_adaptive_avg_pool3d,
"nn.functional.elu": nn_ops.aten_elu,
"nn.functional.leaky_relu": nn_ops.aten_leaky_relu,
"nn.functional.linear": nn_ops.aten_linear,
"nn.functional.relu": nn_ops.aten_relu,
"nn.functional.relu6": nn_ops.aten_relu6,
"nn.functional.selu": core_ops.aten_selu,
"nn.functional.upsample_nearest2d": (
nn_ops.aten_upsample_nearest2d,
_upsample_kwargs_wrangler,
),
"nonzero": core_ops.aten_nonzero,
"ones_like": core_ops.aten_ones_like,
"ones": core_ops.aten_ones,
@@ -194,6 +237,7 @@ def wrapped(fn):
"repeat": core_ops.aten_repeat,
"round": core_ops.aten_round,
"rsqrt": core_ops.aten_rsqrt,
"rsub": core_ops.aten_rsub,
"sigmoid": core_ops.aten_sigmoid,
"sign": core_ops.aten_sign,
"sin": core_ops.aten_sin,
@@ -213,11 +257,17 @@ def wrapped(fn):
TESTED_OPS = frozenset(OPINFO_FUNCTION_MAPPING)

EXPECTED_SKIPS_OR_FAILS = (
skip("clamp", reason="Enable when onnxscript errors are fixed"),
xfail("amax", reason="ONNX Runtime 1.13 does not support ReduceMax-18"),
xfail("amin", reason="ONNX Runtime 1.13 does not support ReduceMin-18"),
skip("clamp", reason="Enable when onnxscript supports optional inputs"),
xfail(
"nn.functional.linear",
reason="ONNX Runtime thinks the graph is invalid",
),
xfail(
"nn.functional.upsample_nearest2d",
reason="enable when ONNX Runtime does support opset18",
),
xfail("round", variant_name="decimals_0", reason="The op does not support decimals"),
xfail("round", variant_name="decimals_3", reason="The op does not support decimals"),
xfail("round", variant_name="decimals_neg_3", reason="The op does not support decimals"),
@@ -228,17 +278,53 @@ def wrapped(fn):
SKIP_SUBTESTS: tuple[DecorateMeta, ...] = (
skip(
"nonzero",
matcher=lambda sample: sample.kwargs.get("as_tuple") is True,
matcher=lambda sample: sample.kwargs.get("as_tuple") is not None,
reason="as_tuple=True is not supported",
),
skip(
"nn.functional.adaptive_avg_pool1d",
# Shape should be [N, C, D1]
matcher=lambda sample: sample.args[0] not in {1, (1,)} or len(sample.input.shape) != 3,
reason="only global pooling is supported; only batched inputs are supported",
),
skip(
"nn.functional.adaptive_avg_pool2d",
matcher=lambda sample: sample.args[0] != (1, 1) or len(sample.input.shape) != 4,
reason="only global pooling is supported; only batched inputs are supported",
),
skip(
"nn.functional.adaptive_avg_pool3d",
matcher=lambda sample: sample.args[0] != (1, 1, 1) or len(sample.input.shape) != 5,
reason="only global pooling is supported; only batched inputs are supported",
),
skip(
"nn.functional.upsample_nearest2d",
# Shape should be [N, C, H, W]
matcher=lambda sample: len(sample.input.shape) != 2 + 2,
reason="only test on 2d inputs",
),
skip(
"nn.functional.upsample_nearest2d",
matcher=lambda sample: "scale_factor" in sample.kwargs,
reason="fixme: the scale_factor tests",
),
)
OP_WITH_SKIPPED_SUBTESTS = frozenset(meta.op_name for meta in SKIP_SUBTESTS)

# END OF SECTION TO MODIFY #####################################################
duplicate_opinfo(
OPS_DB,
"nn.functional.upsample_nearest",
(
"nn.functional.upsample_nearest1d",
"nn.functional.upsample_nearest2d",
"nn.functional.upsample_nearest3d",
),
)


OPS_DB = copy.deepcopy(common_methods_invocations.op_db)
# END OF SECTION TO MODIFY #####################################################


OP_WITH_SKIPPED_SUBTESTS = frozenset(meta.op_name for meta in SKIP_SUBTESTS)
ALL_OPS_IN_DB = frozenset(op_info.name for op_info in OPS_DB)
# Assert all ops in OPINFO_FUNCTION_MAPPING are in the OPS_DB
assert TESTED_OPS.issubset(ALL_OPS_IN_DB), f"{TESTED_OPS - ALL_OPS_IN_DB} not in OPS_DB"
@@ -336,6 +422,13 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
)

onnx_function = OPINFO_FUNCTION_MAPPING[op.name]
kwarg_wrangler = None
if isinstance(onnx_function, tuple):
# Obtain the kwarg_wrangler that manipulates the OpInfo inputs
# to match the aten operator signature
# An example is nn.functional.upsample_nearest2d, which has a different signature
# than the aten operator upsample_nearest2d
onnx_function, kwarg_wrangler = onnx_function

for (i, cpu_sample) in enumerate(samples):
inputs = (cpu_sample.input, *cpu_sample.args)
@@ -347,18 +440,15 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
):
skip_reason = _should_skip_test_sample(op.name, cpu_sample)
if skip_reason is not None:
self.skipTest(skip_reason)
# Cannot use self.skip because pytest would skip the entire test
warnings.warn(f"skipped sample {i}. Reason: {skip_reason}")
continue
input_onnx = [_convert_tensor_to_numpy(x) for x in inputs]
kwargs_onnx = _convert_kwargs_for_onnx(cpu_sample.kwargs)
output_torch = op(*inputs, **cpu_sample.kwargs)
try:
function_output = onnx_function(*input_onnx, **kwargs_onnx)
# pylint: disable=c-extension-no-member
except onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented:
self.skipTest(
f"ONNX Runtime doesn't support running {op.name} with dtype {dtype}",
)
# pylint: enable=c-extension-no-member
if kwarg_wrangler:
kwargs_onnx = kwarg_wrangler(kwargs_onnx)
torch_output = op(*inputs, **cpu_sample.kwargs)
function_output = onnx_function(*input_onnx, **kwargs_onnx)

if dtype == torch.float32:
# Relax atol and rtol for float32 based on empirical results
@@ -369,10 +459,10 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
rtol = None
atol = None

# Use torch testing to ensure dtypes and shapes match
# Use torch.testing as opposed to np.testing to ensure dtypes and shapes match
torch.testing.assert_close(
torch.tensor(function_output),
output_torch,
torch.tensor(torch_output),
rtol=rtol,
atol=atol,
)