Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
55f8dd9
fix: annotate script()
justinchuby Dec 9, 2022
8a3a587
feat(atenlib): clamp, lt, gt
justinchuby Dec 9, 2022
f821b6a
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 9, 2022
aecc148
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
6555a55
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
f8385b0
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
468f86f
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
47b8380
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
00f1760
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
060f9db
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
497cb16
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
9bb4038
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
d24110a
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
cbfb867
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
875f235
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
27008e1
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
3a8737d
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
c5871c8
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
012905c
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
49be5ec
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 13, 2022
3a9c5f6
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 13, 2022
d4f09e8
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 13, 2022
ee3143e
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 13, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions onnxscript/function_libs/torch_aten/ops/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""Commonly shared functions for the function library."""
from __future__ import annotations

import onnxscript
from onnxscript.onnx_opset import opset18 as op


@onnxscript.script()
def ones_like(x, dtype: int):
    """Return a tensor of ones with the same shape as ``x``.

    Args:
        x: Input tensor whose shape is copied.
        dtype: ONNX data type enum (``onnx.TensorProto``) for the output
            elements; becomes the ``to`` attribute of Cast.
    """
    shape = op.Shape(x)
    # Cast the scalar 1 to the requested element type before broadcasting.
    one_dtype = op.Cast(1, to=dtype)
    # Broadcast the scalar one to the shape of x.
    return op.Expand(one_dtype, shape)
69 changes: 54 additions & 15 deletions onnxscript/function_libs/torch_aten/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,24 @@

from typing import Any, Optional, Sequence

import onnx.helper

import onnxscript
from onnxscript import BOOL, INT64
from onnxscript.onnx_opset import default_opset as op
from onnxscript.function_libs.torch_aten.ops import common
from onnxscript.onnx_opset import opset18 as op
from onnxscript.onnx_types import TensorType


@onnxscript.script()
def _ones_like(x, dtype: int):
    """Common function for ones_like.

    Returns a tensor of ones shaped like ``x`` with element type given by the
    ONNX dtype enum ``dtype``.
    """
    # TODO(justinchuby): Put this in another module
    # NOTE(review): this duplicates common.ones_like (imported above); consider
    # reusing that helper once the TODO above is resolved.
    shape = op.Shape(x)
    # Cast the scalar 1 to the requested element type before broadcasting.
    one_dtype = op.Cast(1, to=dtype)
    return op.Expand(one_dtype, shape)


def aten_abs(self: TensorType) -> TensorType:
# abs(Tensor self) -> Tensor

Expand Down Expand Up @@ -747,16 +760,31 @@ def aten_clamp(
raise NotImplementedError()


def aten_clamp_max(self: TensorType, max: float) -> TensorType:
def aten_clamp_max_scalar(self, max_):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to validate the shape of max_ to know it is a scalar or tensor instead of having such information in the function name?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't thought of a good way. Any suggestions?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, for validation, we can specify it in the signature / via some data structure. We can discuss today.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They still need to be two functions though.

# clamp_max(Tensor self, Scalar max) -> Tensor

raise NotImplementedError()
max_ = op.CastLike(max_, self)
return op.Clip(self, None, max_)


def aten_clamp_max_tensor(self, max_):
    # clamp_max(Tensor self, Tensor max) -> Tensor
    # (Tensor-overload comment fixed: this variant takes a Tensor max, not a
    # Scalar; compare aten_clamp_min_tensor.)
    # Elementwise Min with broadcasting clamps self to at most max_.
    return op.Min(self, max_)


def aten_clamp_min(self: TensorType, min: float) -> TensorType:
def aten_clamp_min_scalar(self, min_):
    # clamp_min(Tensor self, Scalar min) -> Tensor
    # NOTE: min_ is a rank 0 tensor.
    # TODO(justinchuby): Specify the type constraints.
    # CastLike makes min_'s element type match self so Clip's inputs agree.
    min_ = op.CastLike(min_, self)
    # No upper bound: pass None for Clip's optional `max` input.
    return op.Clip(self, min_, None)

raise NotImplementedError()

def aten_clamp_min_tensor(self, min_):
    # clamp_min(Tensor self, Tensor min) -> Tensor
    # TODO(justinchuby): Specify the type constraints.
    # Elementwise Max with broadcasting clamps self to at least min_.
    return op.Max(self, min_)


def aten_clip(
Expand Down Expand Up @@ -1958,10 +1986,12 @@ def aten_gru_cell(
raise NotImplementedError()


def aten_gt(self: TensorType, other: TensorType) -> TensorType:
def aten_gt(self, other):
    # gt.Tensor(Tensor self, Tensor other) -> Tensor

    # TODO(justinchuby): Input spec: non bool tensor
    # Boolean inputs can be pre-casted by policy
    # ONNX Greater compares elementwise with broadcasting and yields a BOOL
    # tensor; bool inputs are rejected by the op, hence the TODO above.
    return op.Greater(self, other)


def aten_hamming_window(window_length: int) -> TensorType:
Expand Down Expand Up @@ -2572,10 +2602,12 @@ def aten_lstm_mps_backward(
raise NotImplementedError()


def aten_lt(self: TensorType, other: TensorType) -> TensorType:
def aten_lt(self, other):
    # lt.Tensor(Tensor self, Tensor other) -> Tensor

    # TODO(justinchuby): Input spec: non bool tensor
    # Boolean inputs can be pre-casted by policy
    # ONNX Less compares elementwise with broadcasting and yields a BOOL
    # tensor; bool inputs are rejected by the op, hence the TODO above.
    return op.Less(self, other)


def aten_lu_solve(self: TensorType, LU_data: TensorType, LU_pivots: TensorType) -> TensorType:
Expand Down Expand Up @@ -3440,10 +3472,15 @@ def aten_ones(size: INT64) -> TensorType:
raise NotImplementedError()


def aten_ones_like(self: TensorType, memory_format: Optional[str] = None) -> TensorType:
def aten_ones_like(self, dtype: int = onnx.TensorProto.FLOAT):
    """ones_like.

    Note: dtype is an onnx enum. Users should convert torch dtype to onnx dtype
    before calling this function.
    """
    # ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor

    # Delegate to the shared scripted helper in ops/common.py.
    # NOTE(review): the layout/device/pin_memory/memory_format arguments of the
    # ATen signature are intentionally not modeled here.
    return common.ones_like(self, dtype)


def aten_or(self: TensorType, other: TensorType) -> TensorType:
Expand Down Expand Up @@ -3916,10 +3953,12 @@ def aten_renorm(self: TensorType, p: float, dim: int, maxnorm: float) -> TensorT
raise NotImplementedError()


def aten_repeat(self: TensorType, repeats: INT64) -> TensorType:
def aten_repeat(self, repeats: INT64):
    # repeat(Tensor self, SymInt[] repeats) -> Tensor

    # Build a 1-D all-ones tensor the same length as `repeats`. Using it as an
    # Expand shape broadcasts `self` up to rank len(repeats) without copying
    # data, which Tile requires before repeating each dimension.
    shape = _ones_like(repeats, onnx.TensorProto.INT64)
    expanded = op.Expand(self, shape)
    # Tile repeats each dimension of the rank-lifted tensor `repeats[i]` times.
    return op.Tile(expanded, repeats)


def aten_repeat_interleave(
Expand Down Expand Up @@ -4012,10 +4051,10 @@ def aten_rot90(self: TensorType, k: int = 1, dims: Sequence[int] = (0, 1)) -> Te
raise NotImplementedError()


def aten_round(self: TensorType) -> TensorType:
def aten_round(self):
    # round(Tensor self) -> Tensor
    # ONNX Round rounds halfway cases to the nearest even integer
    # (banker's rounding), which matches torch.round semantics.
    # Float-only: see the xfails registered for "round" in the op tests.
    return op.Round(self)


def aten_row_indices(self: TensorType) -> TensorType:
Expand Down
5 changes: 4 additions & 1 deletion onnxscript/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import inspect
import sys
import textwrap
from typing import Any, Callable, Optional

import onnx.helper

Expand Down Expand Up @@ -52,7 +53,9 @@ def script_check(f: ast.FunctionDef, opset, global_names, source, default_opset=
return convert.top_level_stmt(f)


def script(opset=None, default_opset=None, **kwargs):
def script(
opset: Optional[values.Opset] = None, default_opset: Optional[Any] = None, **kwargs: Any
) -> Callable[[Callable[..., Any]], onnxscript.OnnxFunction]:
"""Main decorator. Declares a function as an onnx function.

Args:
Expand Down
32 changes: 31 additions & 1 deletion onnxscript/test/function_libs/torch_aten/ops_correctness_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import numpy as np
import onnxruntime.capi.onnxruntime_pybind11_state
import parameterized
import torch
from torch.testing._internal import common_device_type, common_methods_invocations
from torch.testing._internal.opinfo import core as opinfo_core
Expand Down Expand Up @@ -156,19 +157,26 @@ def wrapped(fn):
# Modify this section ##########################################################

# Ops to be tested for numerical consistency between onnx and pytorch
OPINFO_FUNCTION_MAPPING = {
# Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py
OPINFO_FUNCTION_MAPPING: dict[str, Callable[..., Any]] = {
"add": core_ops.aten_add,
"gt": core_ops.aten_gt,
"lt": core_ops.aten_lt,
"mul": core_ops.aten_mul,
"nn.functional.elu": nn_ops.aten_elu,
"nn.functional.relu6": nn_ops.aten_relu6,
"nn.functional.selu": core_ops.aten_selu,
"repeat": core_ops.aten_repeat,
"round": core_ops.aten_round,
"sub": core_ops.aten_sub,
}

TESTED_OPS = frozenset(OPINFO_FUNCTION_MAPPING)

EXPECTED_SKIPS_OR_FAILS = (
xfail("add", dtypes=BOOL_TYPES, reason="Add is not defined on bool tensors"),
xfail("gt", dtypes=BOOL_TYPES, reason="Greater is not defined on bool tensors"),
xfail("lt", dtypes=BOOL_TYPES, reason="Less is not defined on bool tensors"),
xfail("mul", dtypes=BOOL_TYPES, reason="Mul is not defined on bool tensors"),
xfail(
"nn.functional.elu",
Expand All @@ -185,6 +193,17 @@ def wrapped(fn):
dtypes=dtypes_except(torch.float16, torch.float32),
reason="ONNX Runtime doesn't support float64 for Selu",
),
xfail(
"round",
variant_name="",
dtypes=dtypes_except(*FLOAT_TYPES),
reason="Round is not defined on non-float tensors",
),
xfail("round", variant_name="decimals_0", reason="The ATen op does not support decimals"),
xfail("round", variant_name="decimals_3", reason="The ATen op does not support decimals"),
xfail(
"round", variant_name="decimals_neg_3", reason="The ATen op does not support decimals"
),
xfail("sub", dtypes=BOOL_TYPES, reason="Sub is not defined on bool tensors"),
)
# END OF SECTION TO MODIFY #####################################################
Expand All @@ -193,6 +212,17 @@ def wrapped(fn):
OPS_DB = copy.deepcopy(common_methods_invocations.op_db)


class TestFunctionsCompilation(unittest.TestCase):
    """Test all functions can be compiled."""

    @parameterized.parameterized.expand(
        list(OPINFO_FUNCTION_MAPPING.items()),
    )
    def test_function_compiles(self, _, function):
        # Compile each mapped ATen implementation with onnxscript and ensure a
        # FunctionProto can be produced — catches translation errors early,
        # independent of the numerical-consistency tests below.
        compiled = onnxscript.script()(function)
        compiled.to_function_proto()


class TestOutputConsistency(unittest.TestCase):
"""Test output consistency between exported ONNX models and PyTorch eager mode.

Expand Down