feat(atenlib): implement aten functions 1/n #247
Changes from 5 commits
First file in the diff: the ATen op definitions module.

@@ -17,8 +17,10 @@
 from typing import Any, Optional, Sequence

 import onnx.helper

 from onnxscript import BOOL, INT64
-from onnxscript.onnx_opset import default_opset as op
+from onnxscript.onnx_opset import opset18 as op
 from onnxscript.onnx_types import TensorType
@@ -747,16 +749,31 @@ def aten_clamp(
     raise NotImplementedError()


-def aten_clamp_max(self: TensorType, max: float) -> TensorType:
+def aten_clamp_max_scalar(self, max_):
     # clamp_max(Tensor self, Scalar max) -> Tensor

-    raise NotImplementedError()
+    max_ = op.CastLike(max_, self)
+    return op.Clip(self, None, max_)
+
+
+def aten_clamp_max_tensor(self, max_):
+    # clamp_max(Tensor self, Scalar max) -> Tensor
+
+    return op.Min(self, max_)


-def aten_clamp_min(self: TensorType, min: float) -> TensorType:
+def aten_clamp_min_scalar(self, min_):
     # clamp_min(Tensor self, Scalar min) -> Tensor
-
-    raise NotImplementedError()
+    # NOTE: min_ is a rank 0 tensor.
+    # TODO(justinchuby): Specify the type constraints.
+    min_ = op.CastLike(min_, self)
+    return op.Clip(self, min_, None)
+
+
+def aten_clamp_min_tensor(self, min_):
+    # clamp_min(Tensor self, Tensor min) -> Tensor
+    # TODO(justinchuby): Specify the type constraints.
+    return op.Max(self, min_)


 def aten_clip(

Review thread on aten_clamp_max_scalar:
- Is it possible to validate the shape of max_ to know it is a scalar or tensor instead of having such information in the function name?
- I haven't thought of a good way. Any suggestions?
- Ah, for validation, we can specify it in the signature / via some data structure. We can discuss today.
- They still need to be two functions though.
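Related to the thread above, one way to keep the scalar/tensor split out of callers is a small dispatch shim at the call site. The sketch below is only an assumption about how that could look; dispatch_clamp_max is a hypothetical helper, not part of this PR.

def dispatch_clamp_max(self, max_):
    # Hypothetical exporter-side dispatch: choose the overload from the argument's
    # Python type rather than asking callers to pick a function by name.
    if isinstance(max_, (int, float)):
        return aten_clamp_max_scalar(self, max_)
    return aten_clamp_max_tensor(self, max_)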
@@ -1958,10 +1975,12 @@ def aten_gru_cell(
     raise NotImplementedError()


-def aten_gt(self: TensorType, other: TensorType) -> TensorType:
+def aten_gt(self, other):
     # gt.Tensor(Tensor self, Tensor other) -> Tensor

-    raise NotImplementedError()
+    # TODO(justinchuby): Input spec: non bool tensor
+    # Boolean inputs can be pre-casted by policy
+    return op.Greater(self, other)


 def aten_hamming_window(window_length: int) -> TensorType:
@@ -2572,10 +2591,12 @@ def aten_lstm_mps_backward(
     raise NotImplementedError()


-def aten_lt(self: TensorType, other: TensorType) -> TensorType:
+def aten_lt(self, other):
     # lt.Tensor(Tensor self, Tensor other) -> Tensor

-    raise NotImplementedError()
+    # TODO(justinchuby): Input spec: non bool tensor
+    # Boolean inputs can be pre-casted by policy
+    return op.Less(self, other)


 def aten_lu_solve(self: TensorType, LU_data: TensorType, LU_pivots: TensorType) -> TensorType:
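The "pre-casted by policy" notes above mean the caller handles boolean inputs before aten_gt / aten_lt are invoked, since ONNX Greater and Less are not defined on BOOL tensors. A minimal sketch of such a caller-side policy, assuming the caller already knows an input is BOOL; the gt_bool name and the INT32 choice are assumptions, not part of this PR.

def gt_bool(self, other):
    # Hypothetical policy: promote BOOL operands so op.Greater accepts them.
    self_int = op.Cast(self, to=onnx.TensorProto.INT32)
    other_int = op.Cast(other, to=onnx.TensorProto.INT32)
    return op.Greater(self_int, other_int)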
@@ -3440,10 +3461,15 @@ def aten_ones(size: INT64) -> TensorType:
     raise NotImplementedError()


-def aten_ones_like(self: TensorType, memory_format: Optional[str] = None) -> TensorType:
+def aten_ones_like(self, dtype: int = onnx.TensorProto.FLOAT):
+    """ones_like.
+
+    Note: dtype is an onnx enum. Users should convert torch dtype to onnx dtype
+    before calling this function.
+    """
     # ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor

-    raise NotImplementedError()
+    return common.ones_like(self, dtype)


 def aten_or(self: TensorType, other: TensorType) -> TensorType:
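common.ones_like itself is not shown in this diff. A sketch of what such a helper could look like with ONNX ops, mirroring the inline "shape = ones_like(repeats)" pattern used in aten_repeat below; the helper name and body are assumptions, not the PR's implementation.

def ones_like_sketch(self, dtype: int):
    one = op.Cast(op.Constant(value_int=1), to=dtype)  # scalar 1 in the requested dtype
    return op.Expand(one, op.Shape(self))              # broadcast it to the input's shape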
@@ -3916,10 +3942,19 @@ def aten_renorm(self: TensorType, p: float, dim: int, maxnorm: float) -> TensorType:
     raise NotImplementedError()


-def aten_repeat(self: TensorType, repeats: INT64) -> TensorType:
+def aten_repeat(self, repeats: INT64["M"]):
     # repeat(Tensor self, SymInt[] repeats) -> Tensor

-    raise NotImplementedError()
+    # FIXME(justinchuby): When repeats.shape == [0]
+
+    # TODO(justinchuby): Make ones_like a function when onnxscript supports it
+    # shape = ones_like(repeats) := {
+    one = op.Constant(value_int=1)
+    repeats_shape = op.Shape(repeats)
+    shape = op.Expand(one, repeats_shape)
+    # }
+    self_expanded = op.Expand(self, shape)
+    return op.Tile(self_expanded, repeats)


 def aten_repeat_interleave(

Review comment on the FIXME line: This is an example where we need shape dependent logic.
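A numpy sketch (not part of the PR) of why Expand precedes Tile here: torch.Tensor.repeat allows len(repeats) to exceed the input's rank, so the input is first broadcast against an all-ones shape of length len(repeats) and only then tiled.

import numpy as np

x = np.array([1, 2, 3])                              # rank-1 input
repeats = (2, 2)                                     # rank-2 repeat spec
ones_shape = np.ones(len(repeats), dtype=np.int64)   # mirrors Expand(1, Shape(repeats))
x_expanded = x * np.ones(ones_shape, dtype=x.dtype)  # mirrors op.Expand(self, shape); shape (1, 3)
result = np.tile(x_expanded, repeats)                # mirrors op.Tile; shape (2, 6)
# torch.tensor([1, 2, 3]).repeat(2, 2) produces the same (2, 6) result.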
@@ -4012,10 +4047,10 @@ def aten_rot90(self: TensorType, k: int = 1, dims: Sequence[int] = (0, 1)) -> TensorType:
     raise NotImplementedError()


-def aten_round(self: TensorType) -> TensorType:
+def aten_round(self):
     # round(Tensor self) -> Tensor

-    raise NotImplementedError()
+    return op.Round(self)


 def aten_row_indices(self: TensorType) -> TensorType:
Second file in the diff: the OpInfo-based op correctness test module.
@@ -8,6 +8,7 @@

 import numpy as np
 import onnxruntime.capi.onnxruntime_pybind11_state
 import parameterized
 import torch
 from torch.testing._internal import common_device_type, common_methods_invocations
 from torch.testing._internal.opinfo import core as opinfo_core
@@ -69,14 +70,15 @@ class DecorateMeta:
     decorator: Callable[..., Any]
     dtypes: Optional[Collection[torch.dtype]]
     reason: str
+    matcher: Optional[Callable[[Any], bool]] = None


 def xfail(
     op_name: str,
     variant_name: str = "",
     *,
+    reason: str,
     dtypes: Optional[Collection[torch.dtype]] = None,
-    reason: Optional[str] = None,
 ):
     """Expects an OpInfo test to fail.


@@ -86,8 +88,6 @@ def xfail(
         dtypes: The dtypes to expect the failure.
         reason: The reason for the failure.
     """
-    if reason is None:
-        raise ValueError("Please specify a reason.")
     return DecorateMeta(
         op_name=op_name,
         variant_name=variant_name,
@@ -101,8 +101,9 @@ def skip(
     op_name: str,
     variant_name: str = "",
     *,
+    reason: str,
     dtypes: Optional[Collection[torch.dtype]] = None,
-    reason: Optional[str] = None,
+    matcher: Optional[Callable[[Any], Any]] = None,
 ):
     """Skips an OpInfo test.


@@ -111,15 +112,16 @@ def skip(
         variant_name: Optional OpInfo variant_test_name.
         dtypes: The dtypes to skip.
         reason: The reason for skipping.
+        matcher: A function that matches the test sample input. It is used only when
+            xfail is in the SKIP_SUBTESTS list.
     """
-    if reason is None:
-        raise ValueError("Please specify a reason.")
     return DecorateMeta(
         op_name=op_name,
         variant_name=variant_name,
         decorator=unittest.skip(f"Don't care: {reason}"),
         dtypes=dtypes,
         reason=reason,
+        matcher=matcher,
     )
@@ -156,19 +158,26 @@ def wrapped(fn):
 # Modify this section ##########################################################

 # Ops to be tested for numerical consistency between onnx and pytorch
-OPINFO_FUNCTION_MAPPING = {
+# Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py
+OPINFO_FUNCTION_MAPPING: dict[str, Callable[..., Any]] = {
     "add": core_ops.aten_add,
+    "gt": core_ops.aten_gt,
+    "lt": core_ops.aten_lt,
     "mul": core_ops.aten_mul,
     "nn.functional.elu": nn_ops.aten_elu,
     "nn.functional.relu6": nn_ops.aten_relu6,
     "nn.functional.selu": core_ops.aten_selu,
+    "repeat": core_ops.aten_repeat,
+    "round": core_ops.aten_round,
     "sub": core_ops.aten_sub,
 }

 TESTED_OPS = frozenset(OPINFO_FUNCTION_MAPPING)

 EXPECTED_SKIPS_OR_FAILS = (
     xfail("add", dtypes=BOOL_TYPES, reason="Add is not defined on bool tensors"),
+    xfail("gt", dtypes=BOOL_TYPES, reason="Greater is not defined on bool tensors"),
+    xfail("lt", dtypes=BOOL_TYPES, reason="Less is not defined on bool tensors"),
     xfail("mul", dtypes=BOOL_TYPES, reason="Mul is not defined on bool tensors"),
     xfail(
         "nn.functional.elu",
@@ -185,14 +194,79 @@ def wrapped(fn):
         dtypes=dtypes_except(torch.float16, torch.float32),
         reason="ONNX Runtime doesn't support float64 for Selu",
     ),
+    # xfail("repeat", reason="fails when repeats is empty."),
+    xfail(
+        "round",
+        variant_name="",
+        dtypes=dtypes_except(*FLOAT_TYPES),
+        reason="Round is not defined on non-float tensors",
+    ),
+    xfail("round", variant_name="decimals_0", reason="The ATen op does not support decimals"),
+    xfail("round", variant_name="decimals_3", reason="The ATen op does not support decimals"),
+    xfail(
+        "round", variant_name="decimals_neg_3", reason="The ATen op does not support decimals"
+    ),
     xfail("sub", dtypes=BOOL_TYPES, reason="Sub is not defined on bool tensors"),
 )


+SKIP_SUBTESTS = (
+    skip(
+        "repeat",
+        reason="repeating when input is a scalar and repeats is empty is not supported.",
+        matcher=lambda sample: sample.args[0] == (),
+    ),
+)
+OP_WITH_SKIPPED_SUBTESTS = frozenset(
+    meta.op_name for meta in SKIP_SUBTESTS
+)

 # END OF SECTION TO MODIFY #####################################################
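A standalone sketch of how the repeat matcher above classifies samples; FakeSample is an assumed stand-in for torch's OpInfo SampleInput and is not part of this PR. Only the empty-repeats case is skipped.

from dataclasses import dataclass

@dataclass
class FakeSample:
    args: tuple  # mimics SampleInput.args

matcher = lambda sample: sample.args[0] == ()  # same matcher as in SKIP_SUBTESTS
print(matcher(FakeSample(args=((),))))         # True  -> subtest is skipped
print(matcher(FakeSample(args=((2, 2),))))     # False -> subtest runs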
 OPS_DB = copy.deepcopy(common_methods_invocations.op_db)


+class TestFunctionsCompilation(unittest.TestCase):
+    """Test all functions can be compiled."""
+
+    @parameterized.parameterized.expand(
+        list(OPINFO_FUNCTION_MAPPING.items()),
+    )
+    def test_function_compiles(self, _, function):
+        compiled = onnxscript.script()(function)
+        compiled.to_function_proto()
+
+
+def _convert_tensor_to_numpy(input: Any) -> Any:
+    if isinstance(input, torch.Tensor):
+        return input.detach().cpu().numpy()
+    if isinstance(input, (tuple, list)):
+        if len(input) == 0:
+            return np.array((), dtype=np.int64)
+        if isinstance(input[0], torch.Tensor):
+            return [_convert_tensor_to_numpy(x) for x in input]
+        if isinstance(input[0], (int, float)):
+            # Just a tuple of numbers
+            return np.array(input)
+        return input
+
+    return input

Review thread on the int64 default in _convert_tensor_to_numpy:
- Why is the default value not float32?
- I found that it's usually an index tuple, so it needs to be int. If there are more subtle cases we can update the logic to handle those.
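A usage sketch of the converter with assumed inputs, illustrating the thread above: the empty tuple typically stands in for an index or shape argument (for example repeats=()), hence the int64 default.

_convert_tensor_to_numpy(torch.tensor([1.0, 2.0]))  # float32 ndarray
_convert_tensor_to_numpy(())                        # np.array((), dtype=np.int64)
_convert_tensor_to_numpy((2, 2))                    # integer ndarray built from a tuple of numbers
_convert_tensor_to_numpy("same")                    # non-tensor, non-sequence inputs pass through unchanged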
+def _should_skip_test_sample(op_name: str, sample) -> Optional[str]:
+    """Returns a reason if a test sample should be skipped."""
+    if op_name not in OP_WITH_SKIPPED_SUBTESTS:
+        return None
+    for decorator_meta in SKIP_SUBTESTS:
+        # Linear search on SKIP_SUBTESTS. That's fine because the list is small.
+        if decorator_meta.op_name == op_name:
+            assert decorator_meta.matcher is not None, "Matcher must be defined"
+            if decorator_meta.matcher(sample):
+                return decorator_meta.reason
+
+    return None


 class TestOutputConsistency(unittest.TestCase):
     """Test output consistency between exported ONNX models and PyTorch eager mode.
@@ -235,7 +309,10 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
                 inputs=repr(inputs),
                 kwargs=repr(cpu_sample.kwargs),
             ):
-                input_numpy = [x.numpy() for x in inputs if isinstance(x, torch.Tensor)]
+                skip_reason = _should_skip_test_sample(op.name, cpu_sample)
+                if skip_reason is not None:
+                    self.skipTest(skip_reason)
+                input_numpy = [_convert_tensor_to_numpy(x) for x in inputs]
                 torch_output = op(*inputs, **cpu_sample.kwargs)
                 try:
                     function_output = scripted_function(*input_numpy, **cpu_sample.kwargs)