Skip to content

feat(atenlib): add, sub, mul #235

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 48 commits into from
Dec 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
2c8ee88
feat(atenlib): establish the aten-lib directory
justinchuby Nov 22, 2022
40af0b4
Update base for Update on "feat(atenlib): establish the aten-lib dire…
justinchuby Nov 22, 2022
ee554d7
Update on "feat(atenlib): establish the aten-lib directory"
justinchuby Nov 22, 2022
fa8ce18
feat(atenlib): Create sample functions and tests
justinchuby Nov 23, 2022
6d426b5
Update base for Update on "feat(atenlib): Create sample functions and…
justinchuby Nov 23, 2022
bb9fbee
Update on "feat(atenlib): Create sample functions and tests with OpInfo"
justinchuby Nov 23, 2022
0aed28e
Update base for Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 23, 2022
91aa327
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 23, 2022
00a081b
Update base for Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 23, 2022
fce8072
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 23, 2022
af56008
Update base for Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 23, 2022
8c0b370
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 23, 2022
bfeaefc
Update base for Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 23, 2022
8dc3d15
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 23, 2022
46aa719
Update base for Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 23, 2022
d87f309
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 23, 2022
f7555e5
Update base for Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 29, 2022
c80ad0e
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 29, 2022
01ac14f
Update base for Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 29, 2022
b730aa9
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 29, 2022
1d57934
Update base for Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 29, 2022
407da68
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 29, 2022
2d1e9aa
Update base for Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 30, 2022
4c3e9c2
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 30, 2022
cede59b
Update base for Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 30, 2022
98bd90c
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 30, 2022
b7144a6
Update base for Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 30, 2022
49fc7f1
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 30, 2022
891fbd1
Update base for Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 30, 2022
63d4775
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 30, 2022
add75f1
Update base for Update on "feat(atenlib): create tests with OpInfo"
justinchuby Dec 5, 2022
2c38122
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Dec 5, 2022
907a77a
Update base for Update on "feat(atenlib): create tests with OpInfo"
justinchuby Dec 5, 2022
d6dfd3d
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Dec 5, 2022
035c0e2
Update base for Update on "feat(atenlib): create tests with OpInfo"
justinchuby Dec 6, 2022
69facc6
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Dec 6, 2022
0c4e631
Update base for Update on "feat(atenlib): create tests with OpInfo"
justinchuby Dec 6, 2022
23771ab
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Dec 6, 2022
e99fac7
feat(atenlib): add, sub, mul
justinchuby Dec 6, 2022
89e67ac
Update base for Update on "feat(atenlib): add, sub, mul"
justinchuby Dec 6, 2022
6b5c50b
Update on "feat(atenlib): add, sub, mul"
justinchuby Dec 6, 2022
e1ba56e
Update base for Update on "feat(atenlib): add, sub, mul"
justinchuby Dec 6, 2022
3f35a2c
Update on "feat(atenlib): add, sub, mul"
justinchuby Dec 6, 2022
fa9dde3
Update base for Update on "feat(atenlib): add, sub, mul"
justinchuby Dec 7, 2022
6c1baf4
Update on "feat(atenlib): add, sub, mul"
justinchuby Dec 7, 2022
c6b2cc4
Update base for Update on "feat(atenlib): add, sub, mul"
justinchuby Dec 7, 2022
0d13e27
Update on "feat(atenlib): add, sub, mul"
justinchuby Dec 7, 2022
1184f6b
Merge branch 'main' into gh/justinchuby/7/head
justinchuby Dec 7, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions onnxscript/function_libs/torch_aten/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from typing import Any, Optional, Sequence

from onnxscript import INT64
from onnxscript import BOOL, INT64
from onnxscript.onnx_opset import default_opset as op
from onnxscript.onnx_types import TensorType

Expand Down Expand Up @@ -60,10 +60,11 @@ def aten_adaptive_max_pool1d(
raise NotImplementedError()


def aten_add(self: TensorType, other: TensorType, alpha: float = 1) -> TensorType:
def aten_add(self, other, alpha: float = 1) -> TensorType:
# add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor

raise NotImplementedError()
if alpha != 1:
other = op.Mul(other, alpha) # type: ignore[arg-type]
return op.Add(self, other)


def aten_addbmm(
Expand Down Expand Up @@ -3109,10 +3110,19 @@ def aten_msort(self: TensorType) -> TensorType:
raise NotImplementedError()


def aten_mul(self: TensorType, other: TensorType) -> TensorType:
def aten_mul(self, other) -> TensorType:
# mul.Tensor(Tensor self, Tensor other) -> Tensor

raise NotImplementedError()
return op.Mul(self, other)


def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BOOL[...] instead of BOOL ? Unless the type-annotation-design is going to change the meaning of BOOL, which is currently a tensor of rank-0

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah but mypy complains about BOOL[...]

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’ll see if I can turn it off

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Couldn't make it work. I will leave it as is for now and update it when things fit together. Currently the types for the ops are also BOOL etc., I think?

"""ONNX Mul doesn't support Boolean, so use And as an equivalent operator."""

# TODO(justinchuby): Handle cases where type reconciliation is not enough,
# since different ONNX operators are used based on different data types.

return op.And(self, other)


def aten_multinomial(
Expand Down Expand Up @@ -4339,10 +4349,13 @@ def aten_stft(
raise NotImplementedError()


def aten_sub(self: TensorType, other: TensorType, alpha: float = 1) -> TensorType:
def aten_sub(self, other, alpha: float = 1) -> TensorType:
# sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor

raise NotImplementedError()
if alpha != 1:
other = op.Mul(other, alpha) # type: ignore[arg-type]

return op.Sub(self, other)


def aten_subtract(self: TensorType, other: TensorType, alpha: float = 1) -> TensorType:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,14 +157,19 @@ def wrapped(fn):

# Ops to be tested for numerical consistency between onnx and pytorch
OPINFO_FUNCTION_MAPPING = {
"add": core_ops.aten_add,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For my own knowledge: what is core_ops a reference to? I see core.py above.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see it's just an alias for the same

"mul": core_ops.aten_mul,
"nn.functional.elu": nn_ops.aten_elu,
"nn.functional.relu6": nn_ops.aten_relu6,
"nn.functional.selu": core_ops.aten_selu,
"sub": core_ops.aten_sub,
}

TESTED_OPS = frozenset(OPINFO_FUNCTION_MAPPING)

EXPECTED_SKIPS_OR_FAILS = (
xfail("add", dtypes=BOOL_TYPES, reason="Add is not defined on bool tensors"),
xfail("mul", dtypes=BOOL_TYPES, reason="Mul is not defined on bool tensors"),
xfail(
"nn.functional.elu",
dtypes=dtypes_except(torch.float16, torch.float32),
Expand All @@ -180,6 +185,7 @@ def wrapped(fn):
dtypes=dtypes_except(torch.float16, torch.float32),
reason="ONNX Runtime doesn't support float64 for Selu",
),
xfail("sub", dtypes=BOOL_TYPES, reason="Sub is not defined on bool tensors"),
)
# END OF SECTION TO MODIFY #####################################################

Expand Down