Skip to content

Commit 4f2ab7d

Browse files
committed
feat(atenlib): add, sub, mul
ghstack-source-id: 89663f7 Pull Request resolved: #235
1 parent 96f8385 commit 4f2ab7d

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,9 @@ def aten_adaptive_max_pool1d(
5757

5858
def aten_add(self: TensorType, other: TensorType, alpha: float = 1) -> TensorType:
5959
# add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
60-
61-
raise NotImplementedError()
60+
if alpha != 1:
61+
other = op.Mul(other, alpha)
62+
return op.Add(self, other)
6263

6364

6465
def aten_addbmm(
@@ -3107,7 +3108,16 @@ def aten_msort(self: TensorType) -> TensorType:
31073108
def aten_mul(self: TensorType, other: TensorType) -> TensorType:
31083109
# mul.Tensor(Tensor self, Tensor other) -> Tensor
31093110

3110-
raise NotImplementedError()
3111+
return op.Mul(self, other)
3112+
3113+
3114+
def aten_mul_bool(self: TensorType, other: TensorType) -> TensorType:
3115+
"""ONNX Mul doesn't support Boolean, so use And as an equivalent operator."""
3116+
3117+
# TODO(justinchuby): Handle cases where type reconciliation is not enough,
3118+
# since different ONNX operators are used based on different data types.
3119+
3120+
return op.And(self, other)
31113121

31123122

31133123
def aten_multinomial(
@@ -4337,7 +4347,10 @@ def aten_stft(
43374347
def aten_sub(self: TensorType, other: TensorType, alpha: float = 1) -> TensorType:
43384348
# sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
43394349

4340-
raise NotImplementedError()
4350+
if alpha != 1:
4351+
other = op.Mul(other, alpha)
4352+
4353+
return op.Sub(self, other)
43414354

43424355

43434356
def aten_subtract(self: TensorType, other: TensorType, alpha: float = 1) -> TensorType:

onnxscript/test/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Callable, Collection, Iterable, Optional, Sequence
88

99
import numpy as np
10+
import onnxruntime.capi.onnxruntime_pybind11_state
1011
import torch
1112
from torch.testing._internal import common_device_type, common_methods_invocations
1213
from torch.testing._internal.opinfo import core as opinfo_core
@@ -156,14 +157,19 @@ def wrapped(fn):
156157

157158
# Ops to be tested for numerical consistency between onnx and pytorch
158159
OPINFO_FUNCTION_MAPPING = {
160+
"add": core_ops.aten_add,
161+
"mul": core_ops.aten_mul,
159162
"nn.functional.elu": nn_ops.aten_elu,
160163
"nn.functional.relu6": nn_ops.aten_relu6,
161164
"nn.functional.selu": core_ops.aten_selu,
165+
"sub": core_ops.aten_sub,
162166
}
163167

164168
TESTED_OPS = frozenset(OPINFO_FUNCTION_MAPPING)
165169

166170
EXPECTED_SKIPS_OR_FAILS = (
171+
xfail("add", dtypes=BOOL_TYPES, reason="Add is not defined on bool tensors"),
172+
xfail("mul", dtypes=BOOL_TYPES, reason="Mul is not defined on bool tensors"),
167173
xfail(
168174
"nn.functional.elu",
169175
dtypes=dtypes_except(torch.float16, torch.float32),
@@ -179,6 +185,7 @@ def wrapped(fn):
179185
dtypes=dtypes_except(torch.float16, torch.float32),
180186
reason="ONNX Runtime doesn't support float64 for Selu",
181187
),
188+
xfail("sub", dtypes=BOOL_TYPES, reason="Sub is not defined on bool tensors"),
182189
)
183190
# END OF SECTION TO MODIFY #####################################################
184191

0 commit comments

Comments
 (0)