
feat(atenlib): create tests with OpInfo #208


Closed
justinchuby wants to merge 20 commits

Changes from 13 commits

Commits (20)
fa8ce18
feat(atenlib): Create sample functions and tests
justinchuby Nov 23, 2022
bb9fbee
Update on "feat(atenlib): Create sample functions and tests with OpInfo"
justinchuby Nov 23, 2022
91aa327
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 23, 2022
fce8072
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 23, 2022
8c0b370
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 23, 2022
8dc3d15
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 23, 2022
d87f309
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 23, 2022
c80ad0e
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 29, 2022
b730aa9
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 29, 2022
407da68
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 29, 2022
4c3e9c2
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 30, 2022
98bd90c
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 30, 2022
49fc7f1
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 30, 2022
63d4775
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 30, 2022
2c38122
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Dec 5, 2022
d6dfd3d
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Dec 5, 2022
69facc6
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Dec 6, 2022
23771ab
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Dec 6, 2022
a8507ec
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Dec 6, 2022
2beb25e
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Dec 6, 2022
6 changes: 4 additions & 2 deletions onnxscript/function_libs/torch_aten/ops/core.py
@@ -13,6 +13,8 @@
 from typing import Any, Optional, Sequence

 from onnxscript import INT64, TensorType
+from onnxscript.function_libs.torch_aten.typing import FloatType
+from onnxscript.onnx_opset import opset18 as op


def aten_abs(self: TensorType) -> TensorType:
@@ -3854,10 +3856,10 @@ def aten_select_scatter(self: TensorType, src: TensorType, dim: int, index: int)
     raise NotImplementedError()


-def aten_selu(self: TensorType) -> TensorType:
+def aten_selu(self: FloatType) -> TensorType:
     # selu(Tensor self) -> Tensor

-    raise NotImplementedError()
+    return op.Selu(self)
Contributor:
Why do we put logic like this inside ONNX Script? As we discussed before, logic specific to PyTorch should stay on the PyTorch exporter side, right?

Contributor:

In a typical symbolic function in the current PyTorch exporter, there are mainly two concerns:

  1. Transform the ATen inputs into ONNX inputs and attributes.
  2. Find a suitable ONNX op for the given ATen op, or combine several ONNX ops to implement the same functionality.

With this change, only part 1 would remain in the PyTorch exporter. Is that expected?

Collaborator Author:

More of (2) should live in the function lib, so that only glue logic (anything not expressible as functions) stays in the exporter. We should aim for minimal op logic in the exporter, while still allowing full control so changes can be made independently.

Contributor:

About "minimal op logic on the exporter", may I know the reason why we are doing this? What's the difference between calling these troch_aten ops and onnx scripts ops from PyTorch exporter?



def aten_set_data(self: TensorType, new_data: TensorType) -> Any:
33 changes: 29 additions & 4 deletions onnxscript/function_libs/torch_aten/ops/nn.py
@@ -12,7 +12,12 @@

 from typing import Optional, Sequence

+from beartype.vale import Is
+from typing_extensions import Annotated
+
 from onnxscript import INT64, TensorType
+from onnxscript.function_libs.torch_aten.typing import FloatType
+from onnxscript.onnx_opset import opset18 as op


def aten_adaptive_avg_pool2d(self: TensorType, output_size: INT64[2]) -> TensorType:
@@ -181,11 +186,28 @@ def aten_cross_entropy_loss(


 def aten_elu(
-    self: TensorType, alpha: float = 1, scale: float = 1, input_scale: float = 1
+    self: FloatType,
+    alpha: float = 1.0,
+    scale: Annotated[float, Is[lambda x: x == 1.0]] = 1.0,
+    input_scale: Annotated[float, Is[lambda x: x == 1.0]] = 1.0,
 ) -> TensorType:
     # elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor

-    raise NotImplementedError()
+    # del scale
+    # del input_scale
+    return op.Elu(self, alpha=alpha)


+def aten_elu__int(
+    self: IntType,
Collaborator Author:
@fatcat-z @xiaowuhu thoughts on reconciling data types that are not supported by the native ONNX op?

Contributor @fatcat-z, Nov 30, 2022:
Even though these are designed for ATen ops, I still suggest the function signatures be ONNX-compatible only. These signatures don't need to support Torch data types that ONNX doesn't support. We should convert those data types to ONNX-compatible (NumPy/Python) ones before calling these aten_op() functions. If necessary, we could provide a helper lib that converts data from the PyTorch format to the ONNX format; such a lib could be a new Python package or live in the ONNX Script package.

Collaborator Author:
Sounds like a good idea to me

Collaborator Author:
@titaiwangms is offering to help think about this aspect.
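A minimal sketch of the helper lib suggested above; the function name and its placement are assumptions, not part of this PR:

import torch

def to_onnx_compatible(value):
    """Convert a PyTorch value into an ONNX-compatible (NumPy/Python) one."""
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    if isinstance(value, (list, tuple)):
        return type(value)(to_onnx_compatible(v) for v in value)
    return value  # Python scalars and strings pass through unchanged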

+    alpha: float = 1.0,
+    scale: Annotated[float, Is[lambda x: x == 1.0]] = 1.0,
+    input_scale: Annotated[float, Is[lambda x: x == 1.0]] = 1.0,
+) -> TensorType:
+    # TODO(justinchuby): Move the type casting logic to exporter?
+    # del scale
+    # del input_scale
+    return op.Elu(op.Cast(self, to=onnxscript.FLOAT), alpha=alpha)


def aten_elu_backward(
@@ -769,10 +791,13 @@ def aten_reflection_pad3d_backward(
raise NotImplementedError()


-def aten_relu6(self: TensorType) -> TensorType:
+def aten_relu6(self: FloatType) -> FloatType:
     # relu6(Tensor self) -> Tensor

-    raise NotImplementedError()
+    # TODO(justinchuby): Create a shortcut for creating constants
+    zero = op.CastLike(op.Constant(value_float=0.0), self)
+    # zero = op.CastLike(0, self)
+    return op.Max(self, zero)


def aten_replication_pad1d(self: TensorType, padding: INT64[2]) -> TensorType:
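The relu6 change above leaves a TODO about a shortcut for creating constants. A hypothetical helper (not part of this PR; assumes the opset18-as-op import from nn.py) might look like:

def const_like(value: float, like):
    """Create a scalar constant cast to the same dtype as `like`."""
    return op.CastLike(op.Constant(value_float=value), like)

# aten_relu6 could then be written as: return op.Max(self, const_like(0.0, self))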
10 changes: 10 additions & 0 deletions onnxscript/function_libs/torch_aten/typing.py
@@ -0,0 +1,10 @@
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from __future__ import annotations

from onnxscript import DOUBLE, FLOAT, FLOAT16, INT16, INT32, INT64, TensorType

FloatType = FLOAT16 | FLOAT | DOUBLE
IntType = INT16 | INT32 | INT64
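For reference, the Annotated[..., Is[...]] constraints used in nn.py above can be enforced at call time by beartype (added to requirements-dev.txt below). A minimal illustration, assuming beartype is wired in as the runtime checker:

from beartype import beartype
from beartype.vale import Is
from typing_extensions import Annotated

@beartype
def elu_like(alpha: float = 1.0, scale: Annotated[float, Is[lambda x: x == 1.0]] = 1.0) -> float:
    """Accepts any float alpha, but constrains scale to exactly 1.0."""
    return alpha

elu_like(alpha=2.0)    # OK
# elu_like(scale=2.0)  # raises a beartype violation: scale must equal 1.0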
228 changes: 228 additions & 0 deletions onnxscript/test/function_libs/torch_aten/ops_correctness_test.py
@@ -0,0 +1,228 @@
"""Test op correctness by comparing with PyTorch results."""
from __future__ import annotations

import copy
import dataclasses
import unittest
from typing import Callable, Collection, Iterable, Optional, Sequence

import numpy as np
import torch
from torch.testing._internal import common_device_type, common_methods_invocations
from torch.testing._internal.opinfo import core as opinfo_core

import onnxscript
from onnxscript.function_libs.torch_aten.ops import core as core_ops
from onnxscript.function_libs.torch_aten.ops import nn as nn_ops

SUPPORTED_DTYPES = (
# Boolean
torch.bool,
# Integers
torch.uint8,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
# Floating types
torch.float16,
torch.float32,
torch.float64,
)

# Convenience tuples for creating dtype lists when skipping or xfailing tests

BOOL_TYPES = (torch.bool,)

INT_TYPES = (
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
)

FLOAT_TYPES = (
torch.float16,
torch.float32,
torch.float64,
)


@dataclasses.dataclass
class DecorateMeta:
"""A dataclass for storing information about a test case to skip or xfail.

Adapted from functorch: functorch/test/common_utils.py
"""

op_name: str
variant_name: str
decorator: Callable
dtypes: Optional[Collection[torch.dtype]]
reason: str


def xfail(
op_name: str,
variant_name: str = "",
*,
dtypes: Optional[Collection[torch.dtype]] = None,
reason: Optional[str] = None,
):
"""Expects an OpInfo test to fail.

Args:
op_name: The name of the operator.
variant_name: Optional OpInfo variant_test_name.
dtypes: The dtypes for which the failure is expected.
reason: The reason for the failure.
"""
if reason is None:
raise ValueError("Please specify a reason.")
return DecorateMeta(
op_name=op_name,
variant_name=variant_name,
decorator=unittest.expectedFailure,
dtypes=dtypes,
reason=reason,
)


def skip(
op_name: str,
variant_name: str = "",
*,
dtypes: Optional[Collection[torch.dtype]] = None,
reason: Optional[str] = None,
):
"""Skips an OpInfo test.

Args:
op_name: The name of the operator.
variant_name: Optional OpInfo variant_test_name.
dtypes: The dtypes to skip.
reason: The reason for skipping.
"""
if reason is None:
raise ValueError("Please specify a reason.")
return DecorateMeta(
op_name=op_name,
variant_name=variant_name,
decorator=unittest.skip(f"Don't care: {reason}"),
dtypes=dtypes,
reason=reason,
)


def add_decorate_info(
all_opinfos: Sequence[opinfo_core.OpInfo],
test_class_name: str,
base_test_name: str,
skip_or_xfails: Iterable[DecorateMeta],
):
"""Decorates OpInfo tests with decorators based on the skip_or_xfails list."""
ops_mapping = {(info.name, info.variant_test_name): info for info in all_opinfos}
for decorate_meta in skip_or_xfails:
opinfo = ops_mapping.get((decorate_meta.op_name, decorate_meta.variant_name))
assert (
opinfo is not None
), f"Couldn't find OpInfo for {decorate_meta}. Did you need to specify variant_name?"
decorators = list(opinfo.decorators)
new_decorator = opinfo_core.DecorateInfo(
decorate_meta.decorator,
test_class_name,
base_test_name,
dtypes=decorate_meta.dtypes,
)
decorators.append(new_decorator)
opinfo.decorators = tuple(decorators)

# This decorator doesn't modify fn in any way
def wrapped(fn):
return fn

return wrapped


# Modify this section ##########################################################

# Ops to be tested for numerical consistency between onnx and pytorch
OPINFO_FUNCTION_MAPPING = {
"nn.functional.elu": nn_ops.aten_elu,
"nn.functional.relu6": nn_ops.aten_relu6,
"nn.functional.selu": core_ops.aten_selu,
}

TESTED_OPS = frozenset(OPINFO_FUNCTION_MAPPING)

EXPECTED_SKIPS_OR_FAILS = (
xfail(
"nn.functional.elu", dtypes=[torch.float64], reason="ORT does not support Elu float64"
),
)
# END OF SECTION TO MODIFY #####################################################


OPS_DB = copy.deepcopy(common_methods_invocations.op_db)


class TestOutputConsistency(unittest.TestCase):
"""Test output consistency between exported ONNX models and PyTorch eager mode.

This is a parameterized test suite.
"""

def setUp(self) -> None:
torch.manual_seed(42)
np.random.seed(42)

@common_device_type.ops(
[info for info in OPS_DB if info.name in TESTED_OPS],
allowed_dtypes=SUPPORTED_DTYPES,
)
@add_decorate_info(
OPS_DB,
"TestOutputConsistency",
"test_output_match",
skip_or_xfails=EXPECTED_SKIPS_OR_FAILS,
)
def test_output_match(self, device: str, dtype: torch.dtype, op):
"""Base test method for testing each opset, used by instantiate_device_type_tests."""
# device is provided by instantiate_device_type_tests, but we only want to run in cpu.
assert device == "cpu"

samples = op.sample_inputs(
Contributor:
Why isn't there "shape" information for a sample input?

Collaborator Author:
The shape is not explicitly used in this test because onnxscript handles it in its evaluator. Any concerns?

device,
dtype,
requires_grad=False,
)

onnx_function = OPINFO_FUNCTION_MAPPING[op.name]
scripted_function = onnxscript.script()(onnx_function)

for (i, cpu_sample) in enumerate(samples):
inputs = (cpu_sample.input, *cpu_sample.args)
# Provide the repr to subtest because tensors are not serializable in parallel test runs
with self.subTest(
sample_num=i,
inputs=repr(inputs),
kwargs=repr(cpu_sample.kwargs),
):
input_numpy = [x.numpy() for x in inputs if isinstance(x, torch.Tensor)]
function_output = scripted_function(*input_numpy, **cpu_sample.kwargs)
torch_output = op(*inputs, **cpu_sample.kwargs)

np.testing.assert_allclose(
function_output,
torch_output.numpy(),
)


common_device_type.instantiate_device_type_tests(
TestOutputConsistency, globals(), only_for="cpu"
)


if __name__ == "__main__":
unittest.main()
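Assuming a standard pytest setup (pytest plugins such as pytest-subtests are listed in requirements-dev.txt below), the new suite can presumably be run with:

python -m pytest onnxscript/test/function_libs/torch_aten/ops_correctness_test.py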
4 changes: 4 additions & 0 deletions requirements-dev.txt
@@ -12,6 +12,8 @@ sphinx-gallery
 pydata_sphinx_theme

 # ATen lib
+typing_extensions
Contributor:
Could you re-sort them in alphabetical order?

Collaborator Author:
Done

+beartype
 types-PyYAML

# Testing
@@ -22,6 +24,8 @@ pytest-subtests
 pytest-xdist
 parameterized
 torch
+expecttest
+pyyaml

# Lint
lintrunner