
feat(atenlib): ops 7/n #279

Merged: 14 commits, Jan 7, 2023
Changes from 3 commits
14 changes: 6 additions & 8 deletions onnxscript/function_libs/torch_aten/ops/core.py
@@ -198,20 +198,18 @@ def aten_alpha_dropout(input: TensorType, p: float, train: bool) -> TensorType:
    raise NotImplementedError()


-def aten_amax(
-    self: TensorType, dim: Optional[Sequence[int]] = None, keepdim: bool = False
-) -> TensorType:
+def aten_amax(self: TReal, dim: Optional[INT64["M"]] = None, keepdim: bool = False) -> TReal:
    # amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor

-    raise NotImplementedError()
+    return op.ReduceMax(self, dim, keepdims=keepdim)


def aten_amin(
-    self: TensorType, dim: Optional[Sequence[int]] = None, keepdim: bool = False
+    self: TensorType, dim: Optional[INT64["M"]] = None, keepdim: bool = False
) -> TensorType:
    # amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor

-    raise NotImplementedError()
+    return op.ReduceMin(self, dim, keepdims=keepdim)


def aten_aminmax(
@@ -4171,10 +4169,10 @@ def aten_rsqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
    return op.Reciprocal(op.Sqrt(self))


-def aten_rsub(self: TensorType, other: TensorType, alpha: float = 1) -> TensorType:
+def aten_rsub(self: TReal, other: TReal, alpha: float = 1) -> TReal:
    # rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor

-    raise NotImplementedError()
+    return op.Mul(op.Sub(other, self), alpha)


def aten_scalar_tensor(s: float) -> TensorType:
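A note on the amax/amin bodies above: ONNX ReduceMax/ReduceMin take the reduction axes as an input and keepdims as an attribute, so they line up directly with torch.amax/amin. The snippet below is an illustrative sanity check in plain torch/numpy, not part of the change itself:

import numpy as np
import torch

# Illustrative check: amax/amin over `dim` with `keepdim` should match a plain
# max/min reduction over the same axes, which is what ReduceMax/ReduceMin
# compute when given `dim` as the axes input and `keepdims=keepdim`.
x = torch.randn(2, 3, 4)
dim, keepdim = (1, 2), True

np.testing.assert_allclose(
    torch.amax(x, dim=dim, keepdim=keepdim).numpy(),
    np.max(x.numpy(), axis=dim, keepdims=keepdim),
)
np.testing.assert_allclose(
    torch.amin(x, dim=dim, keepdim=keepdim).numpy(),
    np.min(x.numpy(), axis=dim, keepdims=keepdim),
)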
34 changes: 26 additions & 8 deletions onnxscript/function_libs/torch_aten/ops/nn.py
@@ -19,20 +19,24 @@
from onnxscript import INT64
from onnxscript.function_libs.torch_aten.registration import torch_op
from onnxscript.function_libs.torch_aten.typing import TFloat, TFloatOrBFloat16, TReal
-from onnxscript.onnx_opset import opset18 as op
+from onnxscript.onnx_opset import opset17 as op
from onnxscript.onnx_types import TensorType


-def aten_adaptive_avg_pool2d(self: TensorType, output_size: INT64) -> TensorType:
+def aten_adaptive_avg_pool2d(self: TFloat, output_size: INT64[2]) -> TFloat:
    # adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor

-    raise NotImplementedError()
+    # assert output_size == [1, 1]
+    # TODO(justinchuby): Specify input constraints
+    return op.GlobalAveragePool(self)


def aten_adaptive_avg_pool3d(self: TensorType, output_size: INT64) -> TensorType:
    # adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor

-    raise NotImplementedError()
+    # assert output_size == [1, 1, 1]
+    # TODO(justinchuby): Specify input constraints
+    return op.GlobalAveragePool(self)


def aten_adaptive_max_pool2d(
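The commented-out asserts and the TODOs above flag the real constraint here: GlobalAveragePool only matches adaptive_avg_pool2d/3d when output_size is all ones. A small torch/numpy sketch of that special case (illustrative only, not part of the PR):

import numpy as np
import torch
import torch.nn.functional as F

# adaptive_avg_pool2d with output_size (1, 1) is a global mean over H and W,
# which is exactly what GlobalAveragePool computes; other output sizes are not
# covered by this mapping.
x = torch.randn(2, 3, 5, 7)
np.testing.assert_allclose(
    F.adaptive_avg_pool2d(x, (1, 1)).numpy(),
    x.numpy().mean(axis=(2, 3), keepdims=True),
    rtol=1e-6,
)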
@@ -1162,15 +1166,29 @@ def aten_upsample_nearest1d_backward(
    raise NotImplementedError()


+@torch_op("aten::upsample_nearest2d")
def aten_upsample_nearest2d(
-    self: TensorType,
-    output_size: INT64,
+    self: TReal,
+    size: INT64,
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
-) -> TensorType:
+) -> TReal:
    # upsample_nearest2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor

-    raise NotImplementedError()
+    self_shape = op.Shape(self)
+    batch_channel = self_shape[:2]
+    output_size = op.Concat(batch_channel, size, axis=0)
+
+    # TODO(justinchuby): Conditionally use scales
+
+    return op.Resize(
+        self,
+        None,
+        None,
+        size,
+        mode="nearest",
+        coordinate_transformation_mode="asymmetric",
+    )


def aten_upsample_nearest2d_backward(
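For upsample_nearest2d, Resize is driven by a target size rather than scales, and the `size` argument carries only the spatial dims (the batch/channel dims are concatenated in from Shape(self)). The sketch below is a rough numpy reference for nearest upsampling under the asymmetric mapping src = floor(dst * in / out); it is illustrative only, and the exact rounding behavior of the Resize call (its nearest_mode) is not pinned down here.

import numpy as np

def nearest_upsample_2d(x: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
    # Pick source pixels with the "asymmetric" mapping src = floor(dst * in / out),
    # the mapping PyTorch's nearest upsampling uses.
    n, c, in_h, in_w = x.shape
    rows = np.floor(np.arange(out_h) * in_h / out_h).astype(np.int64)
    cols = np.floor(np.arange(out_w) * in_w / out_w).astype(np.int64)
    return x[:, :, rows[:, None], cols[None, :]]

x = np.arange(2 * 3 * 2 * 2, dtype=np.float32).reshape(2, 3, 2, 2)
y = nearest_upsample_2d(x, 4, 4)  # each input pixel repeated in a 2x2 block
assert y.shape == (2, 3, 4, 4)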
59 changes: 58 additions & 1 deletion onnxscript/test/function_libs/torch_aten/ops_correctness_test.py
@@ -145,9 +145,22 @@ def wrapped(fn):

# Modify this section ##########################################################

+
+def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
+    if "scale_factor" in kwargs:
+        kwargs["scales_h"] = kwargs["scale_factor"]
+        kwargs["scales_w"] = kwargs["scale_factor"]
+        del kwargs["scale_factor"]
+    if "size" in kwargs:
+        kwargs["size"] = np.array(kwargs["size"])
+    return kwargs
+
+
# Ops to be tested for numerical consistency between onnx and pytorch
# Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py
-OPINFO_FUNCTION_MAPPING: dict[str, onnxscript.OnnxFunction] = {
+OPINFO_FUNCTION_MAPPING: dict[
+    str, onnxscript.OnnxFunction | tuple[onnxscript.OnnxFunction, Callable]
+] = {
    "abs": core_ops.aten_abs,
    "acos": core_ops.aten_acos,
    "acosh": core_ops.aten_acosh,
@@ -183,6 +196,10 @@ def wrapped(fn):
    "nn.functional.relu": nn_ops.aten_relu,
    "nn.functional.relu6": nn_ops.aten_relu6,
    "nn.functional.selu": core_ops.aten_selu,
+    "nn.functional.upsample_nearest2d": (
+        nn_ops.aten_upsample_nearest2d,
+        _upsample_kwargs_wrangler,
+    ),
    "nonzero": core_ops.aten_nonzero,
    "ones_like": core_ops.aten_ones_like,
    "ones": core_ops.aten_ones,
@@ -227,13 +244,44 @@ def wrapped(fn):
        matcher=lambda sample: sample.kwargs.get("as_tuple") is True,
        reason="as_tuple=True is not supported",
    ),
+    skip(
+        "nn.functional.upsample_nearest2d",
+        # Shape should be [N, C, H, W]
+        matcher=lambda sample: len(sample.input.shape) != 2 + 2,
+        reason="only test on 2d inputs",
+    ),
+    skip(
+        "nn.functional.upsample_nearest2d",
+        matcher=lambda sample: "scale_factor" in sample.kwargs,
+        reason="fixme: the scale_factor tests",
+    ),
)
OP_WITH_SKIPPED_SUBTESTS = frozenset(meta.op_name for meta in SKIP_SUBTESTS)

# END OF SECTION TO MODIFY #####################################################


+def duplicate_opinfo(opinfos: list[opinfo_core.OpInfo], name: str, new_names: tuple[str, ...]):
+    """Duplicate an opinfo in the opinfo database and give it a new name."""
+    for opinfo in opinfos:
+        if opinfo.name == name:
+            for new_name in new_names:
+                new_opinfo = copy.deepcopy(opinfo)
+                new_opinfo.name = new_name
+                opinfos.append(new_opinfo)
+            return
+
+
OPS_DB = copy.deepcopy(common_methods_invocations.op_db)
+duplicate_opinfo(
+    OPS_DB,
+    "nn.functional.upsample_nearest",
+    (
+        "nn.functional.upsample_nearest1d",
+        "nn.functional.upsample_nearest2d",
+        "nn.functional.upsample_nearest3d",
+    ),
+)

ALL_OPS_IN_DB = frozenset(op_info.name for op_info in OPS_DB)
# Assert all ops in OPINFO_FUNCTION_MAPPING are in the OPS_DB
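The two skip() entries above are evaluated per sample: the first drops inputs that are not 4D (the Resize path assumes [N, C, H, W]), the second sidesteps the scale_factor samples until scales are wired through. The helper below is a hypothetical sketch of how such matchers get applied; the real skip() metadata and lookup live elsewhere in this test file, so the field names here are assumptions:

from collections import namedtuple

SkipMeta = namedtuple("SkipMeta", ["op_name", "matcher", "reason"])  # assumed shape

def _should_skip(op_name, sample, skip_metas):
    # A subtest is skipped when any entry registered for this op has a matcher
    # that returns True for the given sample; the reason becomes the skip message.
    for meta in skip_metas:
        if meta.op_name == op_name and meta.matcher(sample):
            return meta.reason
    return None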
@@ -332,6 +380,13 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
        )

        onnx_function = OPINFO_FUNCTION_MAPPING[op.name]
+        kwarg_wrangler = None
+        if isinstance(onnx_function, tuple):
+            # Obtain the kwarg_wrangler that manipulates the OpInfo inputs
+            # to match the aten operator signature
+            # An example is nn.functional.upsample_nearest2d, which has a different signature
+            # than the aten operator upsample_nearest2d
+            onnx_function, kwarg_wrangler = onnx_function

        for (i, cpu_sample) in enumerate(samples):
            inputs = (cpu_sample.input, *cpu_sample.args)
@@ -346,6 +401,8 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
                self.skipTest(skip_reason)
            input_onnx = [_convert_tensor_to_numpy(x) for x in inputs]
            kwargs_onnx = _convert_kwargs_for_onnx(cpu_sample.kwargs)
+            if kwarg_wrangler:
+                kwargs_onnx = kwarg_wrangler(kwargs_onnx)
            output_torch = op(*inputs, **cpu_sample.kwargs)
            try:
                function_output = onnx_function(*input_onnx, **kwargs_onnx)
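The change to test_output_match above is the dispatch that makes tuple-valued mapping entries work: a tuple is unpacked into the OnnxFunction plus its kwargs wrangler, and the wrangler runs on the ONNX-converted kwargs just before the function call. A compact, self-contained sketch of that shape (toy values only, not the real mapping):

def _resolve(entry):
    # A mapping value is either a plain callable or a (callable, kwargs-wrangler) pair.
    if isinstance(entry, tuple):
        return entry
    return entry, None

def _toy_wrangler(kwargs):  # stand-in for _upsample_kwargs_wrangler
    kwargs["scales_h"] = kwargs.pop("scale_factor")
    return kwargs

onnx_fn, wrangler = _resolve((print, _toy_wrangler))  # toy entry
kwargs = {"scale_factor": 2.0}
if wrangler:
    kwargs = wrangler(kwargs)
assert kwargs == {"scales_h": 2.0}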