
Commit 1082a20

feat(atenlib): clamp, lt, gt

ghstack-source-id: 6f8c1d4
Pull Request resolved: #247

1 parent f7b9f34 commit 1082a20

2 files changed: +194 -24 lines changed

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 57 additions & 14 deletions
@@ -17,6 +17,8 @@

 from typing import Any, Optional, Sequence

+import onnx
+
 from onnxscript import BOOL, INT64
 from onnxscript.onnx_opset import opset18 as op
 from onnxscript.onnx_types import TensorType
@@ -747,16 +749,31 @@ def aten_clamp(
     raise NotImplementedError()


-def aten_clamp_max(self: TensorType, max: float) -> TensorType:
+def aten_clamp_max_scalar(self, max_):
     # clamp_max(Tensor self, Scalar max) -> Tensor

-    raise NotImplementedError()
+    max_ = op.CastLike(max_, self)
+    return op.Clip(self, None, max_)
+
+
+def aten_clamp_max_tensor(self, max_):
+    # clamp_max(Tensor self, Scalar max) -> Tensor

+    return op.Min(self, max_)

-def aten_clamp_min(self: TensorType, min: float) -> TensorType:
+
+def aten_clamp_min_scalar(self, min_):
     # clamp_min(Tensor self, Scalar min) -> Tensor
+    # NOTE: min_ is a rank 0 tensor.
+    # TODO(justinchuby): Specify the type constraints.
+    min_ = op.CastLike(min_, self)
+    return op.Clip(self, min_, None)

-    raise NotImplementedError()
+
+def aten_clamp_min_tensor(self, min_):
+    # clamp_min(Tensor self, Tensor min) -> Tensor
+    # TODO(justinchuby): Specify the type constraints.
+    return op.Max(self, min_)


 def aten_clip(
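
For reference, the tensor overloads above can be exercised through onnxscript eager mode the same way the correctness test below drives them. This sketch is illustrative only and is not part of the commit; the module path and the eager-call behavior are assumptions taken from this PR's test setup.

# Illustrative sketch (not part of the commit): run the new tensor-overload
# clamp functions eagerly on numpy inputs, mirroring the correctness test.
import numpy as np

import onnxscript
from onnxscript.function_libs.torch_aten.ops import core as core_ops

clamp_max = onnxscript.script()(core_ops.aten_clamp_max_tensor)  # op.Min(self, max_)
clamp_min = onnxscript.script()(core_ops.aten_clamp_min_tensor)  # op.Max(self, min_)

x = np.array([-2.0, 0.5, 3.0], dtype=np.float32)
bound = np.array([1.0, 1.0, 1.0], dtype=np.float32)

np.testing.assert_allclose(clamp_max(x, bound), np.minimum(x, bound))
np.testing.assert_allclose(clamp_min(x, bound), np.maximum(x, bound))
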
@@ -1958,10 +1975,12 @@ def aten_gru_cell(
     raise NotImplementedError()


-def aten_gt(self: TensorType, other: TensorType) -> TensorType:
+def aten_gt(self, other):
     # gt.Tensor(Tensor self, Tensor other) -> Tensor

-    raise NotImplementedError()
+    # TODO(justinchuby): Input spec: non bool tensor
+    # Boolean inputs can be pre-casted by policy
+    return op.Greater(self, other)


 def aten_hamming_window(window_length: int) -> TensorType:
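
The TODO above (and the mirrored one in aten_lt further down) records that ONNX Greater/Less are not defined on bool tensors, which is why the test file adds BOOL_TYPES xfails for gt and lt. The "pre-casted by policy" idea would look roughly like the sketch below; the cast happens on the caller side and is an assumption, not something this commit implements.

# Illustrative caller-side policy (not in the commit): cast bool inputs up to
# int32 before dispatching, since ONNX Greater has no bool overload.
import numpy as np

import onnxscript
from onnxscript.function_libs.torch_aten.ops import core as core_ops

gt = onnxscript.script()(core_ops.aten_gt)

a = np.array([True, False, True])
b = np.array([False, False, True])

result = gt(a.astype(np.int32), b.astype(np.int32))  # expected: [True, False, False]
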
@@ -2572,10 +2591,12 @@ def aten_lstm_mps_backward(
     raise NotImplementedError()


-def aten_lt(self: TensorType, other: TensorType) -> TensorType:
+def aten_lt(self, other):
     # lt.Tensor(Tensor self, Tensor other) -> Tensor

-    raise NotImplementedError()
+    # TODO(justinchuby): Input spec: non bool tensor
+    # Boolean inputs can be pre-casted by policy
+    return op.Less(self, other)


 def aten_lu_solve(self: TensorType, LU_data: TensorType, LU_pivots: TensorType) -> TensorType:
@@ -3440,10 +3461,23 @@ def aten_ones(size: INT64) -> TensorType:
     raise NotImplementedError()


-def aten_ones_like(self: TensorType, memory_format: Optional[str] = None) -> TensorType:
+def aten_ones_like_dtype(self, dtype: int = onnx.TensorProto.FLOAT):
+    """ones_like.
+
+    Note: dtype is an onnx enum. Users should convert torch dtype to onnx dtype
+    before calling this function.
+    """
     # ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor

-    raise NotImplementedError()
+    shape = op.Shape(self)
+    one = op.Cast(1, to=dtype)  # type: ignore[arg-type]
+    return op.Expand(one, shape)
+
+
+def aten_one_like(self):
+    shape = op.Shape(self)
+    one = op.CastLike(1, self)  # type: ignore[arg-type]
+    return op.Expand(one, shape)


 def aten_or(self: TensorType, other: TensorType) -> TensorType:
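
Because dtype here is an ONNX TensorProto enum rather than a torch dtype, callers translate first; the test file below adds a TORCH_TYPE_TO_ONNX table for exactly that purpose. A minimal sketch of the intended call pattern follows (illustrative only; the eager-mode behavior is an assumption based on this PR's tests).

# Illustrative sketch (not part of the commit): aten_ones_like_dtype expects an
# onnx.TensorProto dtype, so a torch dtype must be mapped before the call.
import numpy as np
import onnx

import onnxscript
from onnxscript.function_libs.torch_aten.ops import core as core_ops

ones_like_dtype = onnxscript.script()(core_ops.aten_ones_like_dtype)

x = np.zeros((2, 3), dtype=np.float32)
result = ones_like_dtype(x, dtype=onnx.TensorProto.DOUBLE)
# result keeps x's shape (2, 3) but takes the requested element type (float64).
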
@@ -3916,10 +3950,19 @@ def aten_renorm(self: TensorType, p: float, dim: int, maxnorm: float) -> TensorT
     raise NotImplementedError()


-def aten_repeat(self: TensorType, repeats: INT64) -> TensorType:
+def aten_repeat(self, repeats: INT64):
     # repeat(Tensor self, SymInt[] repeats) -> Tensor

-    raise NotImplementedError()
+    # FIXME(justinchuby): When repeats.shape == [0]
+
+    # TODO(justinchuby): Make ones_like a function when onnxscript supports it
+    # shape = ones_like(repeats) := {
+    one = op.Constant(value_int=1)
+    repeats_shape = op.Shape(repeats)
+    shape = op.Expand(one, repeats_shape)
+    # }
+    self_expanded = op.Expand(self, shape)  # type: ignore[arg-type]
+    return op.Tile(self_expanded, repeats)


 def aten_repeat_interleave(
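
The Expand-then-Tile pairing reproduces torch's repeat semantics: Expand broadcasts self up to the rank of repeats (via an all-ones shape built from Shape(repeats)), and Tile then multiplies each dimension. A small numpy/torch illustration of the same identity (not part of the diff):

# Numpy illustration (not part of the commit) of the Expand + Tile trick used
# in aten_repeat above; torch.Tensor.repeat behaves like np.tile.
import numpy as np
import torch

x = np.arange(6, dtype=np.float32).reshape(2, 3)
repeats = (2, 1, 3)

# shape = Expand(1, Shape(repeats)) -> a vector of ones, one per repeat dim;
# self_expanded = Expand(self, shape) -> x broadcast up to rank len(repeats).
x_expanded = x.reshape((1,) * (len(repeats) - x.ndim) + x.shape)
result = np.tile(x_expanded, repeats)  # Tile(self_expanded, repeats)

expected = torch.from_numpy(x).repeat(*repeats).numpy()
np.testing.assert_array_equal(result, expected)  # both have shape (2, 2, 9)
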
@@ -4012,10 +4055,10 @@ def aten_rot90(self: TensorType, k: int = 1, dims: Sequence[int] = (0, 1)) -> Te
     raise NotImplementedError()


-def aten_round(self: TensorType) -> TensorType:
+def aten_round(self):
     # round(Tensor self) -> Tensor

-    raise NotImplementedError()
+    return op.Round(self)


 def aten_row_indices(self: TensorType) -> TensorType:
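
A direct mapping to ONNX Round works because both ONNX Round and torch.round round halfway cases to the nearest even value; only the decimals variants differ, which is why they are xfailed in the test file below. A quick check of that assumption (illustrative, not part of the commit):

# Illustrative check (not in the commit): ONNX Round and torch.round both use
# round-half-to-even, so halfway cases agree without extra logic in aten_round.
import numpy as np
import torch

import onnxscript
from onnxscript.function_libs.torch_aten.ops import core as core_ops

aten_round = onnxscript.script()(core_ops.aten_round)

x = np.array([0.5, 1.5, 2.5, -0.5, -1.5], dtype=np.float32)
np.testing.assert_array_equal(
    aten_round(x), torch.round(torch.from_numpy(x)).numpy()
)  # both yield [0., 2., 2., -0., -2.]
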

onnxscript/test/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 137 additions & 10 deletions
@@ -7,7 +7,9 @@
 from typing import Any, Callable, Collection, Iterable, Optional, Sequence, TypeVar

 import numpy as np
+import onnx
 import onnxruntime.capi.onnxruntime_pybind11_state
+import parameterized
 import torch
 from torch.testing._internal import common_device_type, common_methods_invocations
 from torch.testing._internal.opinfo import core as opinfo_core
@@ -69,14 +71,15 @@ class DecorateMeta:
     decorator: Callable[..., Any]
     dtypes: Optional[Collection[torch.dtype]]
     reason: str
+    matcher: Optional[Callable[[Any], bool]] = None


 def xfail(
     op_name: str,
     variant_name: str = "",
     *,
+    reason: str,
     dtypes: Optional[Collection[torch.dtype]] = None,
-    reason: Optional[str] = None,
 ):
     """Expects an OpInfo test to fail.

@@ -86,8 +89,6 @@ def xfail(
         dtypes: The dtypes to expect the failure.
         reason: The reason for the failure.
     """
-    if reason is None:
-        raise ValueError("Please specify a reason.")
     return DecorateMeta(
         op_name=op_name,
         variant_name=variant_name,
@@ -101,8 +102,9 @@ def skip(
     op_name: str,
     variant_name: str = "",
     *,
+    reason: str,
     dtypes: Optional[Collection[torch.dtype]] = None,
-    reason: Optional[str] = None,
+    matcher: Optional[Callable[[Any], Any]] = None,
 ):
     """Skips an OpInfo test.

@@ -111,15 +113,16 @@ def skip(
         variant_name: Optional OpInfo variant_test_name.
         dtypes: The dtypes to skip.
         reason: The reason for skipping.
+        matcher: A function that matches the test sample input. It is used only when
+            xfail is in the SKIP_SUBTESTS list.
     """
-    if reason is None:
-        raise ValueError("Please specify a reason.")
     return DecorateMeta(
         op_name=op_name,
         variant_name=variant_name,
         decorator=unittest.skip(f"Don't care: {reason}"),
         dtypes=dtypes,
         reason=reason,
+        matcher=matcher,
     )


@@ -159,17 +162,28 @@ def wrapped(fn):
 # Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py
 OPINFO_FUNCTION_MAPPING: dict[str, Callable[..., Any]] = {
     "add": core_ops.aten_add,
+    "clamp_max": core_ops.aten_clamp_max_tensor,
+    "clamp_min": core_ops.aten_clamp_min_tensor,
+    "gt": core_ops.aten_gt,
+    "lt": core_ops.aten_lt,
     "mul": core_ops.aten_mul,
     "nn.functional.elu": nn_ops.aten_elu,
     "nn.functional.relu6": nn_ops.aten_relu6,
     "nn.functional.selu": core_ops.aten_selu,
+    "ones_like": core_ops.aten_ones_like_dtype,
+    "repeat": core_ops.aten_repeat,
+    "round": core_ops.aten_round,
     "sub": core_ops.aten_sub,
 }

 TESTED_OPS = frozenset(OPINFO_FUNCTION_MAPPING)

 EXPECTED_SKIPS_OR_FAILS = (
     xfail("add", dtypes=BOOL_TYPES, reason="Add is not defined on bool tensors"),
+    xfail("clamp_max", dtypes=BOOL_TYPES, reason="Min is not defined on bool tensors"),
+    xfail("clamp_min", dtypes=BOOL_TYPES, reason="Max is not defined on bool tensors"),
+    xfail("gt", dtypes=BOOL_TYPES, reason="Greater is not defined on bool tensors"),
+    xfail("lt", dtypes=BOOL_TYPES, reason="Less is not defined on bool tensors"),
     xfail("mul", dtypes=BOOL_TYPES, reason="Mul is not defined on bool tensors"),
     xfail(
         "nn.functional.elu",
@@ -186,14 +200,123 @@ def wrapped(fn):
         dtypes=dtypes_except(torch.float16, torch.float32),
         reason="ONNX Runtime doesn't support float64 for Selu",
     ),
+    xfail(
+        "round",
+        variant_name="",
+        dtypes=dtypes_except(*FLOAT_TYPES),
+        reason="Round is not defined on non-float tensors",
+    ),
+    xfail("round", variant_name="decimals_0", reason="The ATen op does not support decimals"),
+    xfail("round", variant_name="decimals_3", reason="The ATen op does not support decimals"),
+    xfail(
+        "round", variant_name="decimals_neg_3", reason="The ATen op does not support decimals"
+    ),
     xfail("sub", dtypes=BOOL_TYPES, reason="Sub is not defined on bool tensors"),
 )
+
+
+SKIP_SUBTESTS = (
+    skip(
+        "clamp_max",
+        reason="Empty tensor not yet supported",
+        matcher=lambda sample: sample.input.size() == torch.Size([0]),
+    ),
+    skip(
+        "clamp_min",
+        reason="Empty tensor not yet supported",
+        matcher=lambda sample: sample.input.size() == torch.Size([0]),
+    ),
+    skip(
+        "repeat",
+        reason="repeating when input is a scalar and repeats is empty is not supported",
+        matcher=lambda sample: sample.args[0] == (),
+    ),
+    skip(
+        "ones_like",
+        # TODO(justinchuby): Test aten_ones_like
+        reason="dtype must be provided for aten_ones_like_dtype",
+        matcher=lambda sample: "dtype" not in sample.kwargs,
+    ),
+)
+OP_WITH_SKIPPED_SUBTESTS = frozenset(meta.op_name for meta in SKIP_SUBTESTS)
+
 # END OF SECTION TO MODIFY #####################################################


 OPS_DB = copy.deepcopy(common_methods_invocations.op_db)


+TORCH_TYPE_TO_ONNX = {
+    torch.bool: onnx.TensorProto.BOOL,
+    torch.uint8: onnx.TensorProto.UINT8,
+    torch.int8: onnx.TensorProto.INT8,
+    torch.int16: onnx.TensorProto.INT16,
+    torch.int32: onnx.TensorProto.INT32,
+    torch.int64: onnx.TensorProto.INT64,
+    torch.float16: onnx.TensorProto.FLOAT16,
+    torch.float32: onnx.TensorProto.FLOAT,
+    torch.float64: onnx.TensorProto.DOUBLE,
+    torch.complex64: onnx.TensorProto.COMPLEX64,
+    torch.complex128: onnx.TensorProto.COMPLEX128,
+    torch.bfloat16: onnx.TensorProto.BFLOAT16,
+}
+
+
+class TestFunctionsCompilation(unittest.TestCase):
+    """Test all functions can be compiled."""
+
+    @parameterized.parameterized.expand(
+        list(OPINFO_FUNCTION_MAPPING.items()),
+    )
+    def test_function_compiles(self, _, function):
+        compiled = onnxscript.script()(function)
+        compiled.to_function_proto()
+
+
+def _convert_tensor_to_numpy(input: Any) -> Any:
+    if isinstance(input, torch.Tensor):
+        return input.detach().cpu().numpy()
+    if isinstance(input, (tuple, list)):
+        if len(input) == 0:
+            return np.array((), dtype=np.int64)
+        if isinstance(input[0], torch.Tensor):
+            return [_convert_tensor_to_numpy(x) for x in input]
+        if isinstance(input[0], (int, float)):
+            # Just a tuple of numbers
+            return np.array(input)
+        return input
+
+    return input
+
+
+def _convert_kwargs_for_onnx(kwargs: dict[str, Any]) -> dict[str, Any]:
+    """Converts kwargs to be compatible with ONNX Runtime.
+
+    ONNX Runtime doesn't support torch.bool, so we convert them to torch.uint8.
+    """
+    new_kwargs = {}
+    for key, value in kwargs.items():
+        if key == "device":
+            continue
+        if key == "dtype":
+            value = TORCH_TYPE_TO_ONNX[value]
+        new_kwargs[key] = value
+    return new_kwargs
+
+
+def _should_skip_test_sample(op_name: str, sample) -> Optional[str]:
+    """Returns a reason if a test sample should be skipped."""
+    if op_name not in OP_WITH_SKIPPED_SUBTESTS:
+        return None
+    for decorator_meta in SKIP_SUBTESTS:
+        # Linear search on SKIP_SUBTESTS. That's fine because the list is small.
+        if decorator_meta.op_name == op_name:
+            assert decorator_meta.matcher is not None, "Matcher must be defined"
+            if decorator_meta.matcher(sample):
+                return decorator_meta.reason
+    return None
+
+
 class TestOutputConsistency(unittest.TestCase):
     """Test output consistency between exported ONNX models and PyTorch eager mode.

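
For context, the helpers above compose as follows inside the test below: sample tensors become numpy arrays via _convert_tensor_to_numpy, torch dtype kwargs become ONNX enums via _convert_kwargs_for_onnx, and matched samples are skipped via _should_skip_test_sample. A small illustration of the kwarg conversion (illustrative only, relying on the helper defined in this hunk):

# Illustrative walk-through (not part of the commit) of the kwarg conversion:
# "device" is dropped and a torch dtype is replaced by its onnx.TensorProto enum.
import onnx
import torch

kwargs = {"dtype": torch.float16, "device": "cpu"}
converted = _convert_kwargs_for_onnx(kwargs)  # helper added in the hunk above
assert converted == {"dtype": onnx.TensorProto.FLOAT16}
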
@@ -236,10 +359,14 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
                 inputs=repr(inputs),
                 kwargs=repr(cpu_sample.kwargs),
             ):
-                input_numpy = [x.numpy() for x in inputs if isinstance(x, torch.Tensor)]
-                torch_output = op(*inputs, **cpu_sample.kwargs)
+                skip_reason = _should_skip_test_sample(op.name, cpu_sample)
+                if skip_reason is not None:
+                    self.skipTest(skip_reason)
+                input_onnx = [_convert_tensor_to_numpy(x) for x in inputs]
+                kwargs_onnx = _convert_kwargs_for_onnx(cpu_sample.kwargs)
+                output_torch = op(*inputs, **cpu_sample.kwargs)
                 try:
-                    function_output = scripted_function(*input_numpy, **cpu_sample.kwargs)
+                    function_output = scripted_function(*input_onnx, **kwargs_onnx)
                     # pylint: disable=c-extension-no-member
                 except onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented:
                     self.skipTest(
@@ -250,7 +377,7 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
                 # Use torch testing to ensure dtypes and shapes match
                 torch.testing.assert_close(
                     torch.tensor(function_output),
-                    torch_output,
+                    output_torch,
                 )

