Skip to content

Commit f802167

Browse files
authored
feat(atenlib): add, sub, mul (#235)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * __->__ #235
1 parent c9b3ac6 commit f802167

File tree

2 files changed

+27
-8
lines changed

2 files changed

+27
-8
lines changed

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from typing import Any, Optional, Sequence
1919

20-
from onnxscript import INT64
20+
from onnxscript import BOOL, INT64
2121
from onnxscript.onnx_opset import default_opset as op
2222
from onnxscript.onnx_types import TensorType
2323

@@ -60,10 +60,11 @@ def aten_adaptive_max_pool1d(
6060
raise NotImplementedError()
6161

6262

63-
def aten_add(self: TensorType, other: TensorType, alpha: float = 1) -> TensorType:
63+
def aten_add(self, other, alpha: float = 1) -> TensorType:
6464
# add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
65-
66-
raise NotImplementedError()
65+
if alpha != 1:
66+
other = op.Mul(other, alpha) # type: ignore[arg-type]
67+
return op.Add(self, other)
6768

6869

6970
def aten_addbmm(
@@ -3109,10 +3110,19 @@ def aten_msort(self: TensorType) -> TensorType:
31093110
raise NotImplementedError()
31103111

31113112

3112-
def aten_mul(self: TensorType, other: TensorType) -> TensorType:
3113+
def aten_mul(self, other) -> TensorType:
31133114
# mul.Tensor(Tensor self, Tensor other) -> Tensor
31143115

3115-
raise NotImplementedError()
3116+
return op.Mul(self, other)
3117+
3118+
3119+
def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL:
3120+
"""ONNX Mul doesn't support Boolean, so use And as an equivalent operator."""
3121+
3122+
# TODO(justinchuby): Handle cases where type reconciliation is not enough,
3123+
# since different ONNX operators are used based on different data types.
3124+
3125+
return op.And(self, other)
31163126

31173127

31183128
def aten_multinomial(
@@ -4339,10 +4349,13 @@ def aten_stft(
43394349
raise NotImplementedError()
43404350

43414351

4342-
def aten_sub(self: TensorType, other: TensorType, alpha: float = 1) -> TensorType:
4352+
def aten_sub(self, other, alpha: float = 1) -> TensorType:
43434353
# sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
43444354

4345-
raise NotImplementedError()
4355+
if alpha != 1:
4356+
other = op.Mul(other, alpha) # type: ignore[arg-type]
4357+
4358+
return op.Sub(self, other)
43464359

43474360

43484361
def aten_subtract(self: TensorType, other: TensorType, alpha: float = 1) -> TensorType:

onnxscript/test/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,14 +157,19 @@ def wrapped(fn):
157157

158158
# Ops to be tested for numerical consistency between onnx and pytorch
159159
OPINFO_FUNCTION_MAPPING = {
160+
"add": core_ops.aten_add,
161+
"mul": core_ops.aten_mul,
160162
"nn.functional.elu": nn_ops.aten_elu,
161163
"nn.functional.relu6": nn_ops.aten_relu6,
162164
"nn.functional.selu": core_ops.aten_selu,
165+
"sub": core_ops.aten_sub,
163166
}
164167

165168
TESTED_OPS = frozenset(OPINFO_FUNCTION_MAPPING)
166169

167170
EXPECTED_SKIPS_OR_FAILS = (
171+
xfail("add", dtypes=BOOL_TYPES, reason="Add is not defined on bool tensors"),
172+
xfail("mul", dtypes=BOOL_TYPES, reason="Mul is not defined on bool tensors"),
168173
xfail(
169174
"nn.functional.elu",
170175
dtypes=dtypes_except(torch.float16, torch.float32),
@@ -180,6 +185,7 @@ def wrapped(fn):
180185
dtypes=dtypes_except(torch.float16, torch.float32),
181186
reason="ONNX Runtime doesn't support float64 for Selu",
182187
),
188+
xfail("sub", dtypes=BOOL_TYPES, reason="Sub is not defined on bool tensors"),
183189
)
184190
# END OF SECTION TO MODIFY #####################################################
185191

0 commit comments

Comments (0)