
feat(atenlib): implement aten functions 1/n #247

Merged

merged 23 commits on Dec 13, 2022
Changes from 5 commits

Commits (23)
55f8dd9  fix: annotate script() (justinchuby, Dec 9, 2022)
8a3a587  feat(atenlib): clamp, lt, gt (justinchuby, Dec 9, 2022)
f821b6a  Update on "feat(atenlib): implement aten functions 1/n" (justinchuby, Dec 9, 2022)
aecc148  Update on "feat(atenlib): implement aten functions 1/n" (justinchuby, Dec 10, 2022)
6555a55  Update on "feat(atenlib): implement aten functions 1/n" (justinchuby, Dec 10, 2022)
f8385b0  Update base for Update on "feat(atenlib): implement aten functions 1/n" (justinchuby, Dec 10, 2022)
468f86f  Update on "feat(atenlib): implement aten functions 1/n" (justinchuby, Dec 10, 2022)
47b8380  Update base for Update on "feat(atenlib): implement aten functions 1/n" (justinchuby, Dec 10, 2022)
00f1760  Update on "feat(atenlib): implement aten functions 1/n" (justinchuby, Dec 10, 2022)
060f9db  Update base for Update on "feat(atenlib): implement aten functions 1/n" (justinchuby, Dec 10, 2022)
497cb16  Update on "feat(atenlib): implement aten functions 1/n" (justinchuby, Dec 10, 2022)
9bb4038  Update base for Update on "feat(atenlib): implement aten functions 1/n" (justinchuby, Dec 12, 2022)
d24110a  Update on "feat(atenlib): implement aten functions 1/n" (justinchuby, Dec 12, 2022)
cbfb867  Update base for Update on "feat(atenlib): implement aten functions 1/n" (justinchuby, Dec 12, 2022)
875f235  Update on "feat(atenlib): implement aten functions 1/n" (justinchuby, Dec 12, 2022)
27008e1  Update base for Update on "feat(atenlib): implement aten functions 1/n" (justinchuby, Dec 12, 2022)
3a8737d  Update on "feat(atenlib): implement aten functions 1/n" (justinchuby, Dec 12, 2022)
c5871c8  Update base for Update on "feat(atenlib): implement aten functions 1/n" (justinchuby, Dec 12, 2022)
012905c  Update on "feat(atenlib): implement aten functions 1/n" (justinchuby, Dec 12, 2022)
49be5ec  Update base for Update on "feat(atenlib): implement aten functions 1/n" (justinchuby, Dec 13, 2022)
3a9c5f6  Update on "feat(atenlib): implement aten functions 1/n" (justinchuby, Dec 13, 2022)
d4f09e8  Update base for Update on "feat(atenlib): implement aten functions 1/n" (justinchuby, Dec 13, 2022)
ee3143e  Update on "feat(atenlib): implement aten functions 1/n" (justinchuby, Dec 13, 2022)
65 changes: 50 additions & 15 deletions onnxscript/function_libs/torch_aten/ops/core.py
@@ -17,8 +17,10 @@

from typing import Any, Optional, Sequence

import onnx.helper

from onnxscript import BOOL, INT64
from onnxscript.onnx_opset import default_opset as op
from onnxscript.onnx_opset import opset18 as op
from onnxscript.onnx_types import TensorType


@@ -747,16 +749,31 @@ def aten_clamp(
raise NotImplementedError()


def aten_clamp_max(self: TensorType, max: float) -> TensorType:
def aten_clamp_max_scalar(self, max_):
Contributor: Is it possible to validate the shape of max_ to know whether it is a scalar or a tensor, instead of encoding that information in the function name?

Collaborator (author): I haven't thought of a good way. Any suggestions?

Collaborator (author): Ah, for validation we can specify it in the signature / via some data structure. We can discuss today.

Collaborator (author): They still need to be two functions, though.

# clamp_max(Tensor self, Scalar max) -> Tensor

raise NotImplementedError()
max_ = op.CastLike(max_, self)
return op.Clip(self, None, max_)


def aten_clamp_max_tensor(self, max_):
# clamp_max(Tensor self, Tensor max) -> Tensor

return op.Min(self, max_)


def aten_clamp_min(self: TensorType, min: float) -> TensorType:
def aten_clamp_min_scalar(self, min_):
# clamp_min(Tensor self, Scalar min) -> Tensor
# NOTE: min_ is a rank 0 tensor.
# TODO(justinchuby): Specify the type constraints.
min_ = op.CastLike(min_, self)
return op.Clip(self, min_, None)

raise NotImplementedError()

def aten_clamp_min_tensor(self, min_):
# clamp_min(Tensor self, Tensor min) -> Tensor
# TODO(justinchuby): Specify the type constraints.
return op.Max(self, min_)
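
Following up on the scalar-vs-tensor review discussion above, one possible direction (a sketch only, not part of this PR; the helper name and the rank check are assumptions) is to keep both function bodies and pick the overload from the rank of the second argument in a thin dispatcher:

# Sketch: choose between the two clamp_max overloads by inspecting the rank of
# `max_`, instead of encoding scalar/tensor in the name at the call site.
import numpy as np


def resolve_clamp_max_overload(max_):
    if np.asarray(max_).ndim == 0:
        # Rank-0 value: the Scalar overload, implemented with Clip.
        return aten_clamp_max_scalar
    # Rank >= 1: the Tensor overload, implemented with Min (broadcasting).
    return aten_clamp_max_tensor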


def aten_clip(
@@ -1958,10 +1975,12 @@ def aten_gru_cell(
raise NotImplementedError()


def aten_gt(self: TensorType, other: TensorType) -> TensorType:
def aten_gt(self, other):
# gt.Tensor(Tensor self, Tensor other) -> Tensor

raise NotImplementedError()
# TODO(justinchuby): Input spec: non bool tensor
# Boolean inputs can be pre-casted by policy
return op.Greater(self, other)
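
The TODO above says boolean inputs "can be pre-casted by policy", i.e. handled outside the ONNX function rather than inside aten_gt. A minimal sketch of such a policy wrapper, assuming the imports used by the correctness test below (the wrapper name and the int32 promotion are illustrative, not part of this PR):

# Sketch: promote bool inputs before calling the scripted aten_gt, since ONNX
# Greater is not defined on bool tensors.
import numpy as np
import onnxscript
from onnxscript.function_libs.torch_aten.ops import core as core_ops

_scripted_gt = onnxscript.script()(core_ops.aten_gt)


def gt_with_bool_policy(self, other):
    if self.dtype == np.bool_:
        self = self.astype(np.int32)
    if other.dtype == np.bool_:
        other = other.astype(np.int32)
    return _scripted_gt(self, other)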


def aten_hamming_window(window_length: int) -> TensorType:
@@ -2572,10 +2591,12 @@ def aten_lstm_mps_backward(
raise NotImplementedError()


def aten_lt(self: TensorType, other: TensorType) -> TensorType:
def aten_lt(self, other):
# lt.Tensor(Tensor self, Tensor other) -> Tensor

raise NotImplementedError()
# TODO(justinchuby): Input spec: non bool tensor
# Boolean inputs can be pre-casted by policy
return op.Less(self, other)


def aten_lu_solve(self: TensorType, LU_data: TensorType, LU_pivots: TensorType) -> TensorType:
@@ -3440,10 +3461,15 @@ def aten_ones(size: INT64) -> TensorType:
raise NotImplementedError()


def aten_ones_like(self: TensorType, memory_format: Optional[str] = None) -> TensorType:
def aten_ones_like(self, dtype: int = onnx.TensorProto.FLOAT):
"""ones_like.

Note: dtype is an onnx enum. Users should convert torch dtype to onnx dtype
before calling this function.
"""
# ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor

raise NotImplementedError()
return common.ones_like(self, dtype)
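
The docstring above asks callers to convert the torch dtype to an ONNX enum before calling aten_ones_like. A minimal sketch of that conversion (the mapping is partial and illustrative; it is not part of this PR):

# Partial mapping from torch dtypes to onnx.TensorProto enum values.
import onnx
import torch

TORCH_TO_ONNX_DTYPE = {
    torch.float16: onnx.TensorProto.FLOAT16,
    torch.float32: onnx.TensorProto.FLOAT,
    torch.float64: onnx.TensorProto.DOUBLE,
    torch.int32: onnx.TensorProto.INT32,
    torch.int64: onnx.TensorProto.INT64,
    torch.bool: onnx.TensorProto.BOOL,
}

# e.g. aten_ones_like(x, dtype=TORCH_TO_ONNX_DTYPE[torch.float16])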


def aten_or(self: TensorType, other: TensorType) -> TensorType:
@@ -3916,10 +3942,19 @@ def aten_renorm(self: TensorType, p: float, dim: int, maxnorm: float) -> TensorType:
raise NotImplementedError()


def aten_repeat(self: TensorType, repeats: INT64) -> TensorType:
def aten_repeat(self, repeats: INT64["M"]):
# repeat(Tensor self, SymInt[] repeats) -> Tensor

raise NotImplementedError()
# FIXME(justinchuby): When repeats.shape == [0]
Collaborator (author): This is an example where we need shape-dependent logic.


# TODO(justinchuby): Make ones_like a function when onnxscript supports it
# shape = ones_like(repeats) := {
one = op.Constant(value_int=1)
repeats_shape = op.Shape(repeats)
shape = op.Expand(one, repeats_shape)
# }
self_expanded = op.Expand(self, shape)
return op.Tile(self_expanded, repeats)
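
For intuition, the Expand-then-Tile pattern above matches torch.Tensor.repeat, which prepends size-1 dimensions when repeats has more entries than the input has dimensions. A quick numpy/torch sanity check (illustrative only, not part of the PR):

# Broadcasting to rank len(repeats) and then tiling reproduces torch.repeat.
import numpy as np
import torch

x = np.arange(6, dtype=np.float32).reshape(2, 3)
repeats = (2, 1, 2)

# Equivalent of Expand(self, ones_like(repeats)): lift x to rank len(repeats).
expanded = np.broadcast_to(x, (1,) * (len(repeats) - x.ndim) + x.shape)
tiled = np.tile(expanded, repeats)  # equivalent of Tile(self_expanded, repeats)

expected = torch.from_numpy(x).repeat(*repeats).numpy()
assert np.array_equal(tiled, expected)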


def aten_repeat_interleave(
@@ -4012,10 +4047,10 @@ def aten_rot90(self: TensorType, k: int = 1, dims: Sequence[int] = (0, 1)) -> TensorType:
raise NotImplementedError()


def aten_round(self: TensorType) -> TensorType:
def aten_round(self):
# round(Tensor self) -> Tensor

raise NotImplementedError()
return op.Round(self)


def aten_row_indices(self: TensorType) -> TensorType:
5 changes: 4 additions & 1 deletion onnxscript/main.py
@@ -9,6 +9,7 @@
import inspect
import sys
import textwrap
from typing import Any, Callable, Optional

import onnx.helper

@@ -52,7 +53,9 @@ def script_check(f: ast.FunctionDef, opset, global_names, source, default_opset=
return convert.top_level_stmt(f)


def script(opset=None, default_opset=None, **kwargs):
def script(
opset: Optional[values.Opset] = None, default_opset: Optional[Any] = None, **kwargs: Any
) -> Callable[[Callable[..., Any]], onnxscript.OnnxFunction]:
"""Main decorator. Declares a function as an onnx function.

Args:
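
For context on the annotated signature, a typical use of script() looks like the short sketch below (the function body mirrors aten_gt from this PR, and to_function_proto() is the call used by the compilation test further down):

# Sketch: script() compiles a Python function written against an ONNX opset
# into an OnnxFunction, which can then be serialized to a FunctionProto.
import onnxscript
from onnxscript.onnx_opset import opset18 as op


@onnxscript.script()
def greater(self, other):
    return op.Greater(self, other)


proto = greater.to_function_proto()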
93 changes: 85 additions & 8 deletions onnxscript/test/function_libs/torch_aten/ops_correctness_test.py
@@ -8,6 +8,7 @@

import numpy as np
import onnxruntime.capi.onnxruntime_pybind11_state
import parameterized
import torch
from torch.testing._internal import common_device_type, common_methods_invocations
from torch.testing._internal.opinfo import core as opinfo_core
@@ -69,14 +70,15 @@ class DecorateMeta:
decorator: Callable[..., Any]
dtypes: Optional[Collection[torch.dtype]]
reason: str
matcher: Optional[Callable[[Any], bool]] = None


def xfail(
op_name: str,
variant_name: str = "",
*,
reason: str,
dtypes: Optional[Collection[torch.dtype]] = None,
reason: Optional[str] = None,
):
"""Expects an OpInfo test to fail.

@@ -86,8 +88,6 @@ def xfail(
dtypes: The dtypes to expect the failure.
reason: The reason for the failure.
"""
if reason is None:
raise ValueError("Please specify a reason.")
return DecorateMeta(
op_name=op_name,
variant_name=variant_name,
@@ -101,8 +101,9 @@ def skip(
op_name: str,
variant_name: str = "",
*,
reason: str,
dtypes: Optional[Collection[torch.dtype]] = None,
reason: Optional[str] = None,
matcher: Optional[Callable[[Any], Any]] = None,
):
"""Skips an OpInfo test.

@@ -111,15 +112,16 @@
variant_name: Optional OpInfo variant_test_name.
dtypes: The dtypes to skip.
reason: The reason for skipping.
matcher: A function that matches the test sample input. It is used only when
the skip is in the SKIP_SUBTESTS list.
"""
if reason is None:
raise ValueError("Please specify a reason.")
return DecorateMeta(
op_name=op_name,
variant_name=variant_name,
decorator=unittest.skip(f"Don't care: {reason}"),
dtypes=dtypes,
reason=reason,
matcher=matcher,
)


@@ -156,19 +158,26 @@ def wrapped(fn):
# Modify this section ##########################################################

# Ops to be tested for numerical consistency between onnx and pytorch
OPINFO_FUNCTION_MAPPING = {
# Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py
OPINFO_FUNCTION_MAPPING: dict[str, Callable[..., Any]] = {
"add": core_ops.aten_add,
"gt": core_ops.aten_gt,
"lt": core_ops.aten_lt,
"mul": core_ops.aten_mul,
"nn.functional.elu": nn_ops.aten_elu,
"nn.functional.relu6": nn_ops.aten_relu6,
"nn.functional.selu": core_ops.aten_selu,
"repeat": core_ops.aten_repeat,
"round": core_ops.aten_round,
"sub": core_ops.aten_sub,
}

TESTED_OPS = frozenset(OPINFO_FUNCTION_MAPPING)

EXPECTED_SKIPS_OR_FAILS = (
xfail("add", dtypes=BOOL_TYPES, reason="Add is not defined on bool tensors"),
xfail("gt", dtypes=BOOL_TYPES, reason="Greater is not defined on bool tensors"),
xfail("lt", dtypes=BOOL_TYPES, reason="Less is not defined on bool tensors"),
xfail("mul", dtypes=BOOL_TYPES, reason="Mul is not defined on bool tensors"),
xfail(
"nn.functional.elu",
@@ -185,14 +194,79 @@ def wrapped(fn):
dtypes=dtypes_except(torch.float16, torch.float32),
reason="ONNX Runtime doesn't support float64 for Selu",
),
# xfail("repeat", reason="fails when repeats is empty."),
xfail(
"round",
variant_name="",
dtypes=dtypes_except(*FLOAT_TYPES),
reason="Round is not defined on non-float tensors",
),
xfail("round", variant_name="decimals_0", reason="The ATen op does not support decimals"),
xfail("round", variant_name="decimals_3", reason="The ATen op does not support decimals"),
xfail(
"round", variant_name="decimals_neg_3", reason="The ATen op does not support decimals"
),
xfail("sub", dtypes=BOOL_TYPES, reason="Sub is not defined on bool tensors"),
)


SKIP_SUBTESTS = (
skip(
"repeat",
reason="repeating when input is a scalar and repeats is empty is not supported.",
matcher=lambda sample: sample.args[0] == (),
),
)
OP_WITH_SKIPPED_SUBTESTS = frozenset(
meta.op_name for meta in SKIP_SUBTESTS
)

# END OF SECTION TO MODIFY #####################################################


OPS_DB = copy.deepcopy(common_methods_invocations.op_db)


class TestFunctionsCompilation(unittest.TestCase):
"""Test all functions can be compiled."""

@parameterized.parameterized.expand(
list(OPINFO_FUNCTION_MAPPING.items()),
)
def test_function_compiles(self, _, function):
compiled = onnxscript.script()(function)
compiled.to_function_proto()


def _convert_tensor_to_numpy(input: Any) -> Any:
if isinstance(input, torch.Tensor):
return input.detach().cpu().numpy()
if isinstance(input, (tuple, list)):
if len(input) == 0:
return np.array((), dtype=np.int64)
Contributor: Why is the default value not float32?

Collaborator (author): I found that it's usually an index tuple, so it needs to be int. If there are more subtle cases we can update the logic to handle them.

if isinstance(input[0], torch.Tensor):
return [_convert_tensor_to_numpy(x) for x in input]
if isinstance(input[0], (int, float)):
# Just a tuple of numbers
return np.array(input)
return input

return input
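
To make the conversions above concrete (illustrative inputs, using the helper as defined in this file):

# Examples of _convert_tensor_to_numpy behavior:
_convert_tensor_to_numpy(torch.tensor([1.0, 2.0]))  # float32 ndarray
_convert_tensor_to_numpy((2, 3))                    # np.array([2, 3])
_convert_tensor_to_numpy(())                        # np.array([], dtype=np.int64), e.g. an empty `repeats`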


def _should_skip_test_sample(op_name: str, sample) -> Optional[str]:
"""Returns a reason if a test sample should be skipped."""
if op_name not in OP_WITH_SKIPPED_SUBTESTS:
return None
for decorator_meta in SKIP_SUBTESTS:
# Linear search on SKIP_SUBTESTS. That's fine because the list is small.
if decorator_meta.op_name == op_name:
assert decorator_meta.matcher is not None, "Matcher must be defined"
if decorator_meta.matcher(sample):
return decorator_meta.reason
return None


class TestOutputConsistency(unittest.TestCase):
"""Test output consistency between exported ONNX models and PyTorch eager mode.

Expand Down Expand Up @@ -235,7 +309,10 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
inputs=repr(inputs),
kwargs=repr(cpu_sample.kwargs),
):
input_numpy = [x.numpy() for x in inputs if isinstance(x, torch.Tensor)]
skip_reason = _should_skip_test_sample(op.name, cpu_sample)
if skip_reason is not None:
self.skipTest(skip_reason)
input_numpy = [_convert_tensor_to_numpy(x) for x in inputs]
torch_output = op(*inputs, **cpu_sample.kwargs)
try:
function_output = scripted_function(*input_numpy, **cpu_sample.kwargs)
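
Distilled, the loop above follows a simple eager-comparison pattern (a sketch with illustrative names and tolerances, not the exact assertion used by the test):

# Compare a scripted ONNX function against the PyTorch eager op on one sample.
import numpy as np


def check_consistency(torch_op, scripted_function, sample):
    inputs = (sample.input, *sample.args)
    input_numpy = [_convert_tensor_to_numpy(x) for x in inputs]
    torch_output = torch_op(*inputs, **sample.kwargs)
    function_output = scripted_function(*input_numpy, **sample.kwargs)
    np.testing.assert_allclose(
        function_output, torch_output.detach().cpu().numpy(), rtol=1e-3, atol=1e-5
    )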