Skip to content

Commit a7ab057

Browse files
committed
feat(atenlib): ops 3/n
A bunch of matmul ops.
ghstack-source-id: da6585a
Pull Request resolved: #255
1 parent 745519d commit a7ab057

File tree

2 files changed

+32
-18
lines changed

2 files changed

+32
-18
lines changed

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,28 +22,22 @@
2222
from onnxscript.onnx_types import TensorType
2323

2424

25-
def aten_abs(self: TensorType) -> TensorType:
25+
def aten_abs(self):
2626
# abs(Tensor self) -> Tensor
2727

28-
raise NotImplementedError()
29-
30-
31-
def aten_absolute(self: TensorType) -> TensorType:
32-
# absolute(Tensor self) -> Tensor
33-
34-
raise NotImplementedError()
28+
return op.Abs(self)
3529

3630

37-
def aten_acos(self: TensorType) -> TensorType:
31+
def aten_acos(self):
3832
# acos(Tensor self) -> Tensor
3933

40-
raise NotImplementedError()
34+
return op.Acos(self)
4135

4236

43-
def aten_acosh(self: TensorType) -> TensorType:
37+
def aten_acosh(self):
4438
# acosh(Tensor self) -> Tensor
4539

46-
raise NotImplementedError()
40+
return op.Acosh(self)
4741

4842

4943
def aten_adaptive_avg_pool1d(self: TensorType, output_size: Sequence[int]) -> TensorType:
@@ -91,12 +85,13 @@ def aten_addcmul(
9185
raise NotImplementedError()
9286

9387

94-
def aten_addmm(
95-
self: TensorType, mat1: TensorType, mat2: TensorType, beta: float = 1, alpha: float = 1
96-
) -> TensorType:
88+
def aten_addmm(self, mat1, mat2, beta: float = 1, alpha: float = 1):
9789
# addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
9890

99-
raise NotImplementedError()
91+
mat1_mat2 = op.MatMul(mat1, mat2)
92+
scaled_mat1_mat2 = op.Mul(mat1_mat2, alpha) # type: ignore[arg-type]
93+
scaled_self = op.Mul(self, beta) # type: ignore[arg-type]
94+
return op.Add(scaled_self, scaled_mat1_mat2) # type: ignore[arg-type]
10095

10196

10297
def aten_addmv(
@@ -611,10 +606,10 @@ def aten_block_diag(tensors: Sequence[TensorType]) -> TensorType:
611606
raise NotImplementedError()
612607

613608

614-
def aten_bmm(self: TensorType, mat2: TensorType) -> TensorType:
609+
def aten_bmm(self, mat2):
615610
# bmm(Tensor self, Tensor mat2) -> Tensor
616611

617-
raise NotImplementedError()
612+
return op.MatMul(self, mat2)
618613

619614

620615
def aten_broadcast_tensors(tensors: Sequence[TensorType]) -> TensorType:

onnxscript/test/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,10 @@ def wrapped(fn):
161161
# Ops to be tested for numerical consistency between onnx and pytorch
162162
# Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py
163163
OPINFO_FUNCTION_MAPPING: dict[str, Callable[..., Any]] = {
164+
"abs": core_ops.aten_abs,
164165
"add": core_ops.aten_add,
166+
"addmm": core_ops.aten_addmm,
167+
"bmm": core_ops.aten_bmm,
165168
"clamp_max": core_ops.aten_clamp_max_tensor,
166169
"clamp_min": core_ops.aten_clamp_min_tensor,
167170
"clamp": core_ops.aten_clamp,
@@ -187,6 +190,22 @@ def wrapped(fn):
187190

188191
EXPECTED_SKIPS_OR_FAILS = (
189192
xfail("add", dtypes=BOOL_TYPES, reason="Add is not defined on bool tensors"),
193+
xfail(
194+
"addmm",
195+
dtypes=[torch.uint8, torch.int8, torch.int16],
196+
reason="MatMul is not defined on int16/int8/uint8 tensors",
197+
),
198+
xfail(
199+
"addmm",
200+
variant_name="decomposed",
201+
dtypes=[torch.uint8, torch.int8, torch.int16],
202+
reason="MatMul is not defined on int16/int8/uint8 tensors",
203+
),
204+
xfail(
205+
"bmm",
206+
dtypes=[torch.uint8, torch.int8, torch.int16],
207+
reason="MatMul is not defined on int16/int8/uint8 tensors",
208+
),
190209
skip("clamp", reason="Enable when onnxscript errors are fixed"),
191210
xfail("clamp_max", dtypes=BOOL_TYPES, reason="Min is not defined on bool tensors"),
192211
xfail("clamp_min", dtypes=BOOL_TYPES, reason="Max is not defined on bool tensors"),

0 commit comments

Comments (0)