feat(atenlib): create tests with OpInfo #208

Changes from 13 commits
onnxscript/function_libs/torch_aten/ops/nn.py

@@ -12,7 +12,12 @@
 from typing import Optional, Sequence

+from beartype.vale import Is
+from typing_extensions import Annotated
+
-from onnxscript import INT64, TensorType
+from onnxscript import FLOAT, INT64, TensorType
+from onnxscript.function_libs.torch_aten.typing import FloatType, IntType
 from onnxscript.onnx_opset import opset18 as op


 def aten_adaptive_avg_pool2d(self: TensorType, output_size: INT64[2]) -> TensorType:
@@ -181,11 +186,28 @@ def aten_cross_entropy_loss(


 def aten_elu(
-    self: TensorType, alpha: float = 1, scale: float = 1, input_scale: float = 1
+    self: FloatType,
+    alpha: float = 1.0,
+    scale: Annotated[float, Is[lambda x: x == 1.0]] = 1.0,
+    input_scale: Annotated[float, Is[lambda x: x == 1.0]] = 1.0,
 ) -> TensorType:
     # elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor

-    raise NotImplementedError()
+    # scale and input_scale are constrained to 1.0, so they are unused here
+    return op.Elu(self, alpha=alpha)
+
+
+def aten_elu__int(
+    self: IntType,
+    alpha: float = 1.0,
+    scale: Annotated[float, Is[lambda x: x == 1.0]] = 1.0,
+    input_scale: Annotated[float, Is[lambda x: x == 1.0]] = 1.0,
+) -> TensorType:
+    # TODO(justinchuby): Move the type casting logic to the exporter?
+    # scale and input_scale are constrained to 1.0, so they are unused here
+    return op.Elu(op.Cast(self, to=FLOAT), alpha=alpha)


 def aten_elu_backward(

Review comments on aten_elu__int:

Comment: Even though these are designed for ATen ops, I still suggest keeping the function signatures ONNX-compatible only. We don't need to support any Torch data type that ONNX does not support in these function signatures.

Reply: Sounds like a good idea to me.

Reply: @titaiwangms is offering to help think about this aspect.
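Note on the Annotated constraints above: they only take effect when the function is checked by beartype. Below is a minimal sketch (not part of this PR; names are illustrative) of how such a validator rejects unsupported argument values at call time:

    from beartype import beartype
    from beartype.vale import Is
    from typing_extensions import Annotated

    # Only 1.0 satisfies the validator; other values fail the check when the
    # decorated function is called.
    OnlyOne = Annotated[float, Is[lambda x: x == 1.0]]

    @beartype
    def elu_like(alpha: float = 1.0, scale: OnlyOne = 1.0) -> float:
        return alpha  # placeholder body for the sketch

    elu_like(alpha=2.0, scale=1.0)  # passes validation
    # elu_like(alpha=2.0, scale=0.5)  # raises a beartype violation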
@@ -769,10 +791,13 @@ def aten_reflection_pad3d_backward(
     raise NotImplementedError()


-def aten_relu6(self: TensorType) -> TensorType:
+def aten_relu6(self: FloatType) -> FloatType:
     # relu6(Tensor self) -> Tensor

-    raise NotImplementedError()
+    # TODO(justinchuby): Create a shortcut for creating constants
+    # relu6(x) = min(max(x, 0), 6)
+    zero = op.CastLike(op.Constant(value_float=0.0), self)
+    six = op.CastLike(op.Constant(value_float=6.0), self)
+    return op.Min(op.Max(self, zero), six)


 def aten_replication_pad1d(self: TensorType, padding: INT64[2]) -> TensorType:
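As a quick sanity check of the intended semantics (illustrative only, not part of the diff), relu6 clamps its input to the range [0, 6]:

    import numpy as np

    x = np.array([-3.0, 0.5, 7.0], dtype=np.float32)
    relu6 = np.minimum(np.maximum(x, 0.0), 6.0)  # -> [0.0, 0.5, 6.0]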
onnxscript/function_libs/torch_aten/typing.py (new file)

@@ -0,0 +1,10 @@
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from __future__ import annotations

from onnxscript import DOUBLE, FLOAT, FLOAT16, INT16, INT32, INT64, TensorType

FloatType = FLOAT16 | FLOAT | DOUBLE
IntType = INT16 | INT32 | INT64
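Because these unions are assigned at import time, the tensor types must support the `|` operator at runtime. An equivalent spelling with typing.Union (a sketch, assuming both forms are treated the same downstream):

    from typing import Union

    from onnxscript import DOUBLE, FLOAT, FLOAT16, INT16, INT32, INT64

    FloatType = Union[FLOAT16, FLOAT, DOUBLE]
    IntType = Union[INT16, INT32, INT64]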
New test file (op correctness tests)

@@ -0,0 +1,228 @@
"""Test op correctness by comparing with PyTorch results.""" | ||
|
||
from __future__ import annotations | ||
|
||
import copy | ||
import dataclasses | ||
import unittest | ||
from typing import Callable, Collection, Iterable, Optional, Sequence | ||
|
||
import numpy as np | ||
import torch | ||
from torch.testing._internal import common_device_type, common_methods_invocations | ||
from torch.testing._internal.opinfo import core as opinfo_core | ||
|
||
import onnxscript | ||
from onnxscript.function_libs.torch_aten.ops import core as core_ops | ||
from onnxscript.function_libs.torch_aten.ops import nn as nn_ops | ||
|
||
SUPPORTED_DTYPES = ( | ||
# Boolean | ||
torch.bool, | ||
# Integers | ||
torch.uint8, | ||
torch.int8, | ||
torch.int16, | ||
torch.int32, | ||
torch.int64, | ||
# Floating types | ||
torch.float16, | ||
torch.float32, | ||
torch.float64, | ||
) | ||
|
||
# Convenience tuples for creating dtype lists when skipping or xfailing tests | ||
|
||
BOOL_TYPES = (torch.bool,) | ||
|
||
INT_TYPES = ( | ||
torch.int8, | ||
torch.int16, | ||
torch.int32, | ||
torch.int64, | ||
torch.uint8, | ||
) | ||
|
||
FLOAT_TYPES = ( | ||
torch.float16, | ||
torch.float32, | ||
torch.float64, | ||
) | ||
|
||
|
||
@dataclasses.dataclass | ||
class DecorateMeta: | ||
"""A dataclass for storing information about a test case to skip or xfail. | ||
|
||
Adapted from functorch: functorch/test/common_utils.py | ||
""" | ||
|
||
op_name: str | ||
variant_name: str | ||
decorator: Callable | ||
|
||
dtypes: Optional[Collection[torch.dtype]] | ||
reason: str | ||
|
||
|
||
def xfail( | ||
op_name: str, | ||
variant_name: str = "", | ||
*, | ||
dtypes: Optional[Collection[torch.dtype]] = None, | ||
reason: Optional[str] = None, | ||
): | ||
"""Expects an OpInfo test to fail. | ||
|
||
Args: | ||
op_name: The name of the operator. | ||
variant_name: Optional OpInfo variant_test_name. | ||
dtypes: The dtypes to expect the failure. | ||
reason: The reason for the failure. | ||
""" | ||
if reason is None: | ||
raise ValueError("Please specify a reason.") | ||
return DecorateMeta( | ||
op_name=op_name, | ||
variant_name=variant_name, | ||
decorator=unittest.expectedFailure, | ||
dtypes=dtypes, | ||
reason=reason, | ||
) | ||
|
||
|
||
def skip( | ||
op_name: str, | ||
variant_name: str = "", | ||
*, | ||
dtypes: Optional[Collection[torch.dtype]] = None, | ||
reason: Optional[str] = None, | ||
): | ||
"""Skips an OpInfo test. | ||
|
||
Args: | ||
op_name: The name of the operator. | ||
variant_name: Optional OpInfo variant_test_name. | ||
dtypes: The dtypes to skip. | ||
reason: The reason for skipping. | ||
""" | ||
if reason is None: | ||
raise ValueError("Please specify a reason.") | ||
return DecorateMeta( | ||
op_name=op_name, | ||
variant_name=variant_name, | ||
decorator=unittest.skip(f"Don't care: {reason}"), | ||
dtypes=dtypes, | ||
reason=reason, | ||
) | ||
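For illustration, a hypothetical entry built with these helpers (the op name and reason are invented for the example):

    example_xfail = xfail(
        "nn.functional.celu",  # hypothetical op name
        dtypes=[torch.float16],
        reason="example: dtype not supported by the runtime",
    )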
def add_decorate_info(
    all_opinfos: Sequence[opinfo_core.OpInfo],
    test_class_name: str,
    base_test_name: str,
    skip_or_xfails: Iterable[DecorateMeta],
):
    """Decorates OpInfo tests with decorators based on the skip_or_xfails list."""
    ops_mapping = {(info.name, info.variant_test_name): info for info in all_opinfos}
    for decorate_meta in skip_or_xfails:
        opinfo = ops_mapping.get((decorate_meta.op_name, decorate_meta.variant_name))
        assert (
            opinfo is not None
        ), f"Couldn't find OpInfo for {decorate_meta}. Did you need to specify variant_name?"
        decorators = list(opinfo.decorators)
        new_decorator = opinfo_core.DecorateInfo(
            decorate_meta.decorator,
            test_class_name,
            base_test_name,
            dtypes=decorate_meta.dtypes,
        )
        decorators.append(new_decorator)
        opinfo.decorators = tuple(decorators)

    # This decorator doesn't modify fn in any way
    def wrapped(fn):
        return fn

    return wrapped
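add_decorate_info does its real work as a side effect: it attaches a DecorateInfo to each matching OpInfo in the database, then returns an identity wrapper so it can still be stacked as a decorator. A minimal sketch of the same pattern (names are illustrative):

    def register_side_effect(registry: list):
        registry.append("registered")  # runs once, at decoration time

        def wrapped(fn):
            return fn  # the decorated test function is returned unchanged

        return wrapped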
# Modify this section ##########################################################

# Ops to be tested for numerical consistency between onnx and pytorch
OPINFO_FUNCTION_MAPPING = {
    "nn.functional.elu": nn_ops.aten_elu,
    "nn.functional.relu6": nn_ops.aten_relu6,
    "nn.functional.selu": core_ops.aten_selu,
}

TESTED_OPS = frozenset(OPINFO_FUNCTION_MAPPING)

EXPECTED_SKIPS_OR_FAILS = (
    xfail(
        "nn.functional.elu", dtypes=[torch.float64], reason="ORT does not support Elu float64"
    ),
)
# END OF SECTION TO MODIFY #####################################################
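To cover an additional op, one would extend this section, for example (the op and function names here are hypothetical):

    # OPINFO_FUNCTION_MAPPING["nn.functional.celu"] = nn_ops.aten_celu
    # EXPECTED_SKIPS_OR_FAILS += (
    #     xfail("nn.functional.celu", dtypes=BOOL_TYPES, reason="example reason"),
    # )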
OPS_DB = copy.deepcopy(common_methods_invocations.op_db)


class TestOutputConsistency(unittest.TestCase):
    """Test output consistency between exported ONNX models and PyTorch eager mode.

    This is a parameterized test suite.
    """

    def setUp(self) -> None:
        torch.manual_seed(42)
        np.random.seed(42)

    @common_device_type.ops(
        [info for info in OPS_DB if info.name in TESTED_OPS],
        allowed_dtypes=SUPPORTED_DTYPES,
    )
    @add_decorate_info(
        OPS_DB,
        "TestOutputConsistency",
        "test_output_match",
        skip_or_xfails=EXPECTED_SKIPS_OR_FAILS,
    )
    def test_output_match(self, device: str, dtype: torch.dtype, op):
        """Base test method for testing each op, used by instantiate_device_type_tests."""
        # device is provided by instantiate_device_type_tests, but we only want to run on cpu.
        assert device == "cpu"

        samples = op.sample_inputs(
            device,
            dtype,
            requires_grad=False,
        )

Review comment on op.sample_inputs: Why isn't there "shape" information for a sample input?

Reply: The shape is not explicitly used in this test because onnxscript can handle it in its evaluator. Any considerations?

        onnx_function = OPINFO_FUNCTION_MAPPING[op.name]
        scripted_function = onnxscript.script()(onnx_function)

        for (i, cpu_sample) in enumerate(samples):
            inputs = (cpu_sample.input, *cpu_sample.args)
            # Provide the repr to subtest because tensors are not serializable in parallel test runs
            with self.subTest(
                sample_num=i,
                inputs=repr(inputs),
                kwargs=repr(cpu_sample.kwargs),
            ):
                input_numpy = [x.numpy() for x in inputs if isinstance(x, torch.Tensor)]
                function_output = scripted_function(*input_numpy, **cpu_sample.kwargs)
                torch_output = op(*inputs, **cpu_sample.kwargs)

                np.testing.assert_allclose(
                    function_output,
                    torch_output.numpy(),
                )
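One caveat worth noting (an observation, not part of the PR): np.testing.assert_allclose defaults to rtol=1e-7, which may be too strict for float16 results; a looser, dtype-aware tolerance could be passed explicitly, e.g.:

    np.testing.assert_allclose(function_output, torch_output.numpy(), rtol=1e-3, atol=1e-5)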
common_device_type.instantiate_device_type_tests(
    TestOutputConsistency, globals(), only_for="cpu"
)
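instantiate_device_type_tests replaces the template class with device-specific subclasses; assuming PyTorch's usual naming scheme, the generated class here would be TestOutputConsistencyCPU, which is what test runners discover.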
if __name__ == "__main__":
    unittest.main()
Dev requirements file

@@ -12,6 +12,8 @@ sphinx-gallery
pydata_sphinx_theme

# ATen lib
typing_extensions
beartype
types-PyYAML

# Testing

@@ -22,6 +24,8 @@ pytest-subtests
pytest-xdist
parameterized
torch
expecttest
pyyaml

# Lint
lintrunner

Review comment on the new requirements: Could you re-sort them in alphabetical order?

Reply: Done.
General review discussion:

Comment: Why do we put logic like this within ONNX Script? As we discussed before, logic specific to PyTorch should be left on the PyTorch exporter side, right?

Comment: A typical symbolic function in the current PyTorch exporter mainly focuses on two things: (1) the glue logic that adapts PyTorch-specific inputs and attributes, and (2) the op logic that builds the ONNX ops implementing the operator. By doing this, we would leave only part 1 in the PyTorch exporter. Is this expected?

Reply: More of (2) should be in the function lib, so that only the glue logic (the part not expressed by functions) lives in the exporter. We should aim for minimal op logic in the exporter, while still allowing full control so that changes can be made independently.

Reply: About "minimal op logic on the exporter": may I know the reason we are doing this? What is the difference between calling these torch_aten ops and calling onnxscript ops from the PyTorch exporter?