Skip to content

Commit 11f5fe6

Browse files
committed
feat(atenlib): clamp, lt, gt
ghstack-source-id: fa8de17 Pull Request resolved: #247
1 parent f34e2b0 commit 11f5fe6

File tree

3 files changed

+91
-16
lines changed

3 files changed

+91
-16
lines changed
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""Commonly shared functions for the function library."""
2+
from __future__ import annotations
3+
4+
import onnx.helper
5+
6+
from onnxscript.onnx_opset import opset18 as op
7+
8+
9+
def ones_like(x, onnx_dtype: int):
    """Return a constant tensor of ones with the same shape as `x`.

    Args:
        x: Tensor whose runtime shape determines the output shape.
        onnx_dtype: An onnx.TensorProto data-type enum for the output elements.
    """
    x_shape = op.Shape(x)
    # The `value` attribute must be a one-element tensor; ConstantOfShape
    # broadcasts it across `x_shape`.
    return op.ConstantOfShape(
        x_shape, value=onnx.helper.make_tensor("one", onnx_dtype, [1], [1])
    )

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 47 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,11 @@
1717

1818
from typing import Any, Optional, Sequence
1919

20+
import onnx.helper
21+
2022
from onnxscript import BOOL, INT64
21-
from onnxscript.onnx_opset import default_opset as op
23+
from onnxscript.function_libs.torch_aten.ops import common
24+
from onnxscript.onnx_opset import opset18 as op
2225
from onnxscript.onnx_types import TensorType
2326

2427

@@ -747,16 +750,31 @@ def aten_clamp(
747750
raise NotImplementedError()
748751

749752

750-
def aten_clamp_max(self: TensorType, max: float) -> TensorType:
753+
def aten_clamp_max_scalar(self, max_):
    # clamp_max(Tensor self, Scalar max) -> Tensor

    # Cast the scalar bound to self's element type so Clip's inputs agree.
    upper_bound = op.CastLike(max_, self)
    return op.Clip(self, None, upper_bound)
754758

755759

756-
def aten_clamp_min(self: TensorType, min: float) -> TensorType:
760+
def aten_clamp_max_tensor(self, max_):
    # clamp_max.Tensor(Tensor self, Tensor max) -> Tensor
    # TODO(justinchuby): Specify the type constraints.

    # Fixed: the signature comment previously duplicated the Scalar
    # overload's `(Tensor self, Scalar max)` signature; this is the Tensor
    # overload. Element-wise Min clamps each value from above.
    return op.Min(self, max_)
764+
765+
766+
def aten_clamp_min_scalar(self, min_):
    # clamp_min(Tensor self, Scalar min) -> Tensor
    # NOTE: min_ is a rank 0 tensor.
    # TODO(justinchuby): Specify the type constraints.

    # Cast the scalar bound to self's element type so Clip's inputs agree.
    lower_bound = op.CastLike(min_, self)
    return op.Clip(self, lower_bound, None)
758772

759-
raise NotImplementedError()
773+
774+
def aten_clamp_min_tensor(self, min_):
    # clamp_min(Tensor self, Tensor min) -> Tensor
    # TODO(justinchuby): Specify the type constraints.

    # Element-wise Max clamps each value from below by the tensor bound.
    return op.Max(self, min_)
760778

761779

762780
def aten_clip(
@@ -1958,10 +1976,12 @@ def aten_gru_cell(
19581976
raise NotImplementedError()
19591977

19601978

1961-
def aten_gt(self: TensorType, other: TensorType) -> TensorType:
1979+
def aten_gt(self, other):
    # gt.Tensor(Tensor self, Tensor other) -> Tensor

    # TODO(justinchuby): Constrain the input spec to non-boolean tensors;
    # boolean inputs can be pre-cast by policy before reaching this op.
    return op.Greater(self, other)
19651985

19661986

19671987
def aten_hamming_window(window_length: int) -> TensorType:
@@ -2572,10 +2592,12 @@ def aten_lstm_mps_backward(
25722592
raise NotImplementedError()
25732593

25742594

2575-
def aten_lt(self: TensorType, other: TensorType) -> TensorType:
2595+
def aten_lt(self, other):
    # lt.Tensor(Tensor self, Tensor other) -> Tensor

    # TODO(justinchuby): Constrain the input spec to non-boolean tensors;
    # boolean inputs can be pre-cast by policy before reaching this op.
    return op.Less(self, other)
25792601

25802602

25812603
def aten_lu_solve(self: TensorType, LU_data: TensorType, LU_pivots: TensorType) -> TensorType:
@@ -3440,10 +3462,17 @@ def aten_ones(size: INT64) -> TensorType:
34403462
raise NotImplementedError()
34413463

34423464

3443-
def aten_ones_like(self: TensorType, memory_format: Optional[str] = None) -> TensorType:
3465+
def aten_ones_like(self, dtype: Optional[int] = None):
    """Return a tensor of ones with the same shape as `self`.

    Args:
        dtype: Element type of the result. NOTE(review): the docstring of the
            original said dtype is a torch ScalarType enum, but the value is
            forwarded to `common.ones_like` unconverted and defaults to
            `onnx.TensorProto.FLOAT` — so an ONNX dtype is what actually works
            today; the torch-to-ONNX conversion is still TODO.
    """
    # ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor

    # TODO(justinchuby): Create a helper to convert torch dtype to ONNX dtype
    if dtype is None:
        dtype = onnx.TensorProto.FLOAT
    return common.ones_like(self, dtype)
34473476

34483477

34493478
def aten_or(self: TensorType, other: TensorType) -> TensorType:
@@ -3916,10 +3945,13 @@ def aten_renorm(self: TensorType, p: float, dim: int, maxnorm: float) -> TensorT
39163945
raise NotImplementedError()
39173946

39183947

3919-
def aten_repeat(self: TensorType, repeats: INT64) -> TensorType:
3948+
def aten_repeat(self, repeats: INT64):
    # repeat(Tensor self, SymInt[] repeats) -> Tensor

    # FIXME(justinchuby): 'common' is not an instance of type Opset but <class 'module'>.
    # Expand `self` against a tensor of ones shaped like `repeats` —
    # presumably to lift `self` to the rank of `repeats` before tiling,
    # matching torch.Tensor.repeat's leading-dimension behavior; TODO confirm.
    shape = common.ones_like(repeats, onnx.TensorProto.INT64)
    expanded = op.Expand(self, shape)
    return op.Tile(expanded, repeats)
39233955

39243956

39253957
def aten_repeat_interleave(
@@ -4012,10 +4044,10 @@ def aten_rot90(self: TensorType, k: int = 1, dims: Sequence[int] = (0, 1)) -> Te
40124044
raise NotImplementedError()
40134045

40144046

4015-
def aten_round(self: TensorType) -> TensorType:
4047+
def aten_round(self):
    # round(Tensor self) -> Tensor

    # Direct one-to-one mapping onto the ONNX Round operator.
    return op.Round(self)
40194051

40204052

40214053
def aten_row_indices(self: TensorType) -> TensorType:

onnxscript/test/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import numpy as np
1010
import onnxruntime.capi.onnxruntime_pybind11_state
11+
import parameterized
1112
import torch
1213
from torch.testing._internal import common_device_type, common_methods_invocations
1314
from torch.testing._internal.opinfo import core as opinfo_core
@@ -156,19 +157,26 @@ def wrapped(fn):
156157
# Modify this section ##########################################################
157158

158159
# Ops to be tested for numerical consistency between onnx and pytorch
# Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py
OPINFO_FUNCTION_MAPPING: dict[str, Callable[..., Any]] = {
    "add": core_ops.aten_add,
    "gt": core_ops.aten_gt,
    "lt": core_ops.aten_lt,
    "mul": core_ops.aten_mul,
    "nn.functional.elu": nn_ops.aten_elu,
    "nn.functional.relu6": nn_ops.aten_relu6,
    "nn.functional.selu": core_ops.aten_selu,
    # NOTE(review): "repeat" looks intentionally disabled — presumably due to
    # the FIXME in core_ops.aten_repeat about `common` not being an Opset;
    # confirm before re-enabling.
    # "repeat": core_ops.aten_repeat,
    "round": core_ops.aten_round,
    "sub": core_ops.aten_sub,
}

# The set of op names under test, derived from the mapping's keys.
TESTED_OPS = frozenset(OPINFO_FUNCTION_MAPPING)
169175

170176
EXPECTED_SKIPS_OR_FAILS = (
171177
xfail("add", dtypes=BOOL_TYPES, reason="Add is not defined on bool tensors"),
178+
xfail("gt", dtypes=BOOL_TYPES, reason="Greater is not defined on bool tensors"),
179+
xfail("lt", dtypes=BOOL_TYPES, reason="Less is not defined on bool tensors"),
172180
xfail("mul", dtypes=BOOL_TYPES, reason="Mul is not defined on bool tensors"),
173181
xfail(
174182
"nn.functional.elu",
@@ -185,6 +193,17 @@ def wrapped(fn):
185193
dtypes=dtypes_except(torch.float16, torch.float32),
186194
reason="ONNX Runtime doesn't support float64 for Selu",
187195
),
196+
xfail(
197+
"round",
198+
variant_name="",
199+
dtypes=dtypes_except(*FLOAT_TYPES),
200+
reason="Round is not defined on non-float tensors",
201+
),
202+
xfail("round", variant_name="decimals_0", reason="The ATen op does not support decimals"),
203+
xfail("round", variant_name="decimals_3", reason="The ATen op does not support decimals"),
204+
xfail(
205+
"round", variant_name="decimals_neg_3", reason="The ATen op does not support decimals"
206+
),
188207
xfail("sub", dtypes=BOOL_TYPES, reason="Sub is not defined on bool tensors"),
189208
)
190209
# END OF SECTION TO MODIFY #####################################################
@@ -193,6 +212,17 @@ def wrapped(fn):
193212
OPS_DB = copy.deepcopy(common_methods_invocations.op_db)
194213

195214

215+
class TestFunctionsCompilation(unittest.TestCase):
    """Checks that every mapped ATen function compiles as an onnxscript function."""

    @parameterized.parameterized.expand(
        list(OPINFO_FUNCTION_MAPPING.items()),
    )
    def test_function_compiles(self, _, function):
        # Compilation raises if the body is not valid onnxscript; converting
        # to a FunctionProto exercises the full translation path.
        script_fn = onnxscript.script()(function)
        script_fn.to_function_proto()
224+
225+
196226
class TestOutputConsistency(unittest.TestCase):
197227
"""Test output consistency between exported ONNX models and PyTorch eager mode.
198228

0 commit comments

Comments
 (0)