feat(atenlib): create tests with OpInfo #208


Closed · wants to merge 20 commits

Commits (20):
fa8ce18
feat(atenlib): Create sample functions and tests
justinchuby Nov 23, 2022
bb9fbee
Update on "feat(atenlib): Create sample functions and tests with OpInfo"
justinchuby Nov 23, 2022
91aa327
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 23, 2022
fce8072
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 23, 2022
8c0b370
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 23, 2022
8dc3d15
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 23, 2022
d87f309
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 23, 2022
c80ad0e
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 29, 2022
b730aa9
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 29, 2022
407da68
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 29, 2022
4c3e9c2
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 30, 2022
98bd90c
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 30, 2022
49fc7f1
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 30, 2022
63d4775
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 30, 2022
2c38122
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Dec 5, 2022
d6dfd3d
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Dec 5, 2022
69facc6
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Dec 6, 2022
23771ab
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Dec 6, 2022
a8507ec
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Dec 6, 2022
2beb25e
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Dec 6, 2022
7 changes: 5 additions & 2 deletions onnxscript/function_libs/torch_aten/ops/core.py
@@ -3,6 +3,7 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# mypy: disable-error-code=misc
# mypy: disable-error-code=arg-type
# mypy: disable-error-code=type-arg
# mypy: disable-error-code=valid-type
# mypy: disable-error-code=assignment
@@ -17,6 +18,8 @@
from typing import Any, Optional, Sequence

from onnxscript import INT64, TensorType
from onnxscript.function_libs.torch_aten.typing import TFloat
from onnxscript.onnx_opset import default_opset as op


def aten_abs(self: TensorType) -> TensorType:
@@ -4109,10 +4112,10 @@ def aten_select_scatter(self: TensorType, src: TensorType, dim: int, index: int)
raise NotImplementedError()


-def aten_selu(self: TensorType) -> TensorType:
+def aten_selu(self: TFloat) -> TensorType:
     # selu(Tensor self) -> Tensor

-    raise NotImplementedError()
+    return op.Selu(self)
Contributor:

Why do we put logic like this within ONNX Script? As we discussed before, logic specific to PyTorch should be left on the PyTorch exporter side, right?

Contributor:

A typical symbolic function in the current PyTorch exporter mainly does two things:

  1. Transform the aten inputs into ONNX inputs and attributes.
  2. Find a suitable ONNX op for the given aten op, or combine several ONNX ops to implement the same behavior.

With this change, only part 1 would be left in the PyTorch exporter. Is this expected?

Collaborator (author):

More of (2) should be in the function lib, so only the glue logic (anything not expressible as functions) lives in the exporter. We should aim for minimal op logic in the exporter, while still allowing full control so changes can be made independently.

Contributor:

About "minimal op logic on the exporter", may I know the reason why we are doing this? What's the difference between calling these troch_aten ops and onnx scripts ops from PyTorch exporter?



def aten_set_data(self: TensorType, new_data: TensorType) -> Any:
23 changes: 19 additions & 4 deletions onnxscript/function_libs/torch_aten/ops/nn.py
@@ -3,6 +3,7 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# mypy: disable-error-code=misc
# mypy: disable-error-code=arg-type
# mypy: disable-error-code=type-arg
# mypy: disable-error-code=valid-type
# mypy: disable-error-code=assignment
@@ -12,11 +13,19 @@
- All functions should not have the script() decorator. This is because
we want to delay the compilation of the function.
"""

# pylint: disable=unused-argument

from __future__ import annotations

from typing import Optional, Sequence

from beartype.vale import Is
from typing_extensions import Annotated

from onnxscript import INT64, TensorType
from onnxscript.function_libs.torch_aten.typing import TFloat
from onnxscript.onnx_opset import default_opset as op


def aten_adaptive_avg_pool2d(self: TensorType, output_size: INT64) -> TensorType:
@@ -185,11 +194,16 @@ def aten_cross_entropy_loss(


 def aten_elu(
-    self: TensorType, alpha: float = 1, scale: float = 1, input_scale: float = 1
+    self: TFloat,
+    alpha: float = 1.0,
+    scale: Annotated[float, Is[lambda x: x == 1.0]] = 1.0,
+    input_scale: Annotated[float, Is[lambda x: x == 1.0]] = 1.0,
 ) -> TensorType:
     # elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor

-    raise NotImplementedError()
+    # del scale
+    # del input_scale
+    return op.Elu(self, alpha=alpha)
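
The Annotated[..., Is[...]] annotations above are beartype validators: scale and input_scale are pinned to 1.0, since ONNX Elu only exposes an alpha attribute. A minimal sketch of how such a validator behaves once enforced by beartype (toy function, illustrative only, not part of the PR):

    from beartype import beartype
    from beartype.vale import Is
    from typing_extensions import Annotated

    # A float that must equal 1.0, checked at call time.
    UnitFloat = Annotated[float, Is[lambda x: x == 1.0]]

    @beartype
    def toy(scale: UnitFloat = 1.0) -> float:
        return scale

    toy(1.0)  # passes
    # toy(2.0) would raise beartype.roar.BeartypeCallHintParamViolation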


def aten_elu_backward(
Expand Down Expand Up @@ -773,10 +787,11 @@ def aten_reflection_pad3d_backward(
raise NotImplementedError()


-def aten_relu6(self: TensorType) -> TensorType:
+# TODO(justinchuby): Use TFloat as return type
+def aten_relu6(self: TFloat) -> TensorType:
     # relu6(Tensor self) -> Tensor

-    raise NotImplementedError()
+    return op.Min(op.Relu(self), op.Constant(value_float=6.0))  # type: ignore[arg-type]
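
For intuition, relu6 clamps activations to the range [0, 6]; the Min/Relu composition above matches this NumPy reference (a sketch for comparison, not part of the PR):

    import numpy as np

    def relu6_reference(x: np.ndarray) -> np.ndarray:
        # relu6(x) = min(max(x, 0), 6)
        return np.minimum(np.maximum(x, 0.0), 6.0)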


def aten_replication_pad1d(self: TensorType, padding: INT64) -> TensorType:
256 changes: 256 additions & 0 deletions onnxscript/test/function_libs/torch_aten/ops_correctness_test.py
@@ -0,0 +1,256 @@
"""Test op correctness by comparing with PyTorch results."""
from __future__ import annotations

import copy
import dataclasses
import unittest
from typing import Any, Callable, Collection, Iterable, Optional, Sequence, TypeVar

import numpy as np
import onnxruntime.capi.onnxruntime_pybind11_state
import torch
from torch.testing._internal import common_device_type, common_methods_invocations
from torch.testing._internal.opinfo import core as opinfo_core

import onnxscript
from onnxscript.function_libs.torch_aten.ops import core as core_ops
from onnxscript.function_libs.torch_aten.ops import nn as nn_ops

T = TypeVar("T")

SUPPORTED_DTYPES = (
# Boolean
torch.bool,
# Integers
torch.uint8,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
# Floating types
torch.float16,
torch.float32,
torch.float64,
)

# Convenience tuples for creating dtype lists when skipping or xfailing tests

BOOL_TYPES = (torch.bool,)

INT_TYPES = (
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
)

FLOAT_TYPES = (
torch.float16,
torch.float32,
torch.float64,
)


def dtypes_except(*dtypes: torch.dtype) -> Sequence[torch.dtype]:
"""Returns all dtypes except the ones specified."""
return tuple(dtype for dtype in SUPPORTED_DTYPES if dtype not in dtypes)
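
# Example (illustrative): dtypes_except(*FLOAT_TYPES) returns the bool and
# integer dtypes, which is handy for xfailing ops that only support
# floating point; see EXPECTED_SKIPS_OR_FAILS below for real entries.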


@dataclasses.dataclass
class DecorateMeta:
"""A dataclass for storing information about a test case to skip or xfail.

Adapted from functorch: functorch/test/common_utils.py
"""

op_name: str
variant_name: str
decorator: Callable[..., Any]
dtypes: Optional[Collection[torch.dtype]]
reason: str


def xfail(
op_name: str,
variant_name: str = "",
*,
dtypes: Optional[Collection[torch.dtype]] = None,
reason: Optional[str] = None,
):
"""Expects an OpInfo test to fail.

Args:
op_name: The name of the operator.
variant_name: Optional OpInfo variant_test_name.
dtypes: The dtypes to expect the failure.
reason: The reason for the failure.
"""
if reason is None:
raise ValueError("Please specify a reason.")
return DecorateMeta(
op_name=op_name,
variant_name=variant_name,
decorator=unittest.expectedFailure,
dtypes=dtypes,
reason=reason,
)


def skip(
op_name: str,
variant_name: str = "",
*,
dtypes: Optional[Collection[torch.dtype]] = None,
reason: Optional[str] = None,
):
"""Skips an OpInfo test.

Args:
op_name: The name of the operator.
variant_name: Optional OpInfo variant_test_name.
dtypes: The dtypes to skip.
reason: The reason for skipping.
"""
if reason is None:
raise ValueError("Please specify a reason.")
return DecorateMeta(
op_name=op_name,
variant_name=variant_name,
decorator=unittest.skip(f"Don't care: {reason}"),
dtypes=dtypes,
reason=reason,
)


def add_decorate_info(
all_opinfos: Sequence[opinfo_core.OpInfo],
test_class_name: str,
base_test_name: str,
skip_or_xfails: Iterable[DecorateMeta],
) -> Callable[[T], T]:
"""Decorates OpInfo tests with decorators based on the skip_or_xfails list."""
ops_mapping = {(info.name, info.variant_test_name): info for info in all_opinfos}
for decorate_meta in skip_or_xfails:
opinfo = ops_mapping.get((decorate_meta.op_name, decorate_meta.variant_name))
assert (
opinfo is not None
), f"Couldn't find OpInfo for {decorate_meta}. Did you need to specify variant_name?"
decorators = list(opinfo.decorators)
new_decorator = opinfo_core.DecorateInfo(
decorate_meta.decorator,
test_class_name,
base_test_name,
dtypes=decorate_meta.dtypes,
)
decorators.append(new_decorator)
opinfo.decorators = tuple(decorators)

# This decorator doesn't modify fn in any way
def wrapped(fn):
return fn

return wrapped
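
# The mutation above attaches DecorateInfo entries to each matching OpInfo;
# the returned `wrapped` is a no-op so this helper can sit on the test
# method as a decorator.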


# Modify this section ##########################################################

# Ops to be tested for numerical consistency between onnx and pytorch
OPINFO_FUNCTION_MAPPING = {
"nn.functional.elu": nn_ops.aten_elu,
"nn.functional.relu6": nn_ops.aten_relu6,
"nn.functional.selu": core_ops.aten_selu,
}

TESTED_OPS = frozenset(OPINFO_FUNCTION_MAPPING)

EXPECTED_SKIPS_OR_FAILS = (
xfail(
"nn.functional.elu",
dtypes=dtypes_except(torch.float16, torch.float32),
reason="ONNX Runtime doesn't support float64 for Elu",
),
xfail(
"nn.functional.relu6",
dtypes=dtypes_except(torch.float16, torch.float32),
reason="ONNX Runtime doesn't support float64 for Relu",
),
xfail(
"nn.functional.selu",
dtypes=dtypes_except(torch.float16, torch.float32),
reason="ONNX Runtime doesn't support float64 for Selu",
),
)
# END OF SECTION TO MODIFY #####################################################


OPS_DB = copy.deepcopy(common_methods_invocations.op_db)


class TestOutputConsistency(unittest.TestCase):
"""Test output consistency between exported ONNX models and PyTorch eager mode.

This is a parameterized test suite.
"""

def setUp(self) -> None:
torch.manual_seed(42)
np.random.seed(42)

@common_device_type.ops( # type: ignore[misc]
[info for info in OPS_DB if info.name in TESTED_OPS],
allowed_dtypes=SUPPORTED_DTYPES,
)
@add_decorate_info(
OPS_DB,
"TestOutputConsistency",
"test_output_match",
skip_or_xfails=EXPECTED_SKIPS_OR_FAILS,
)
def test_output_match(self, device: str, dtype: torch.dtype, op):
"""Base test method for testing each opset, used by instantiate_device_type_tests."""
# device is provided by instantiate_device_type_tests, but we only want to run on cpu.
assert device == "cpu"

samples = op.sample_inputs(
Contributor:

Why isn't there "shape" information for a sample input?

Collaborator (author):

The shape is not explicitly used in this test because onnxscript handles it in its evaluator. Are there any concerns?

device,
dtype,
requires_grad=False,
)
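
# Each sample yielded above is a SampleInput (torch.testing._internal.opinfo)
# bundling .input, .args and .kwargs for one invocation of the op
# (kwargs example, illustrative: {"alpha": 2.0}).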

onnx_function = OPINFO_FUNCTION_MAPPING[op.name]
scripted_function = onnxscript.script()(onnx_function)

for (i, cpu_sample) in enumerate(samples):
inputs = (cpu_sample.input, *cpu_sample.args)
# Provide the repr to subtest because tensors are not serializable in parallel test runs
with self.subTest(
sample_num=i,
inputs=repr(inputs),
kwargs=repr(cpu_sample.kwargs),
):
input_numpy = [x.numpy() for x in inputs if isinstance(x, torch.Tensor)]
torch_output = op(*inputs, **cpu_sample.kwargs)
try:
function_output = scripted_function(*input_numpy, **cpu_sample.kwargs)
# pylint: disable=c-extension-no-member
except onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented:
self.skipTest(
f"ONNX Runtime doesn't support running {op.name} with dtype {dtype}",
)
# pylint: enable=c-extension-no-member

# Use torch testing to ensure dtypes and shapes match
torch.testing.assert_close(
Contributor:

Do we need to consider the case of multiple outputs?

How about this:

assert all(torch.allclose(o, torch.tensor(o_ort)) for o, o_ort in zip(torch_output, function_output))

Collaborator (author):

Good catch. I am inclined to keep it as is for now and expand to multi-output when needed.

torch.tensor(function_output),
torch_output,
)
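
# A possible extension for multi-output ops, per the review thread above
# (hypothetical sketch; single-output is sufficient for now):
#   for torch_out, onnx_out in zip(torch_outputs, function_outputs):
#       torch.testing.assert_close(torch.tensor(onnx_out), torch_out)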


common_device_type.instantiate_device_type_tests(
TestOutputConsistency, globals(), only_for="cpu"
)
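
# instantiate_device_type_tests generates a concrete per-device test class
# (e.g. TestOutputConsistencyCPU) from the template class above,
# parameterized over the selected ops and dtypes.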


if __name__ == "__main__":
unittest.main()
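
Assuming the dev dependencies from requirements-dev.txt below are installed, the new suite can be run directly with pytest; the -k filter here is illustrative:

    pytest onnxscript/test/function_libs/torch_aten/ops_correctness_test.py -k relu6
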
10 changes: 7 additions & 3 deletions requirements-dev.txt
@@ -12,15 +12,19 @@ sphinx-gallery
pydata_sphinx_theme

# ATen lib
beartype
types-PyYAML
typing_extensions

 # Testing
-pytest!=7.1.0
-pytest-cov
+expecttest
+parameterized
+pytest-azurepipelines
+pytest-cov
 pytest-subtests
 pytest-xdist
-parameterized
+pytest!=7.1.0
+pyyaml
+torch

# Lint