
Commit d99e4b8

feat(atenlib): implement aten functions 1/n (#247)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):

* __->__ #247

Implemented ops:

- lt
- gt
- round
- clamp_min
- clamp_max
- clamp
- repeat
- ones_like

Test:

- Create the ability to skip particular sub tests
- Create a helper function to transform torch inputs into numpy for onnxscript to run on
1 parent 566cdc3 commit d99e4b8

File tree

2 files changed: 204 additions, 28 deletions

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 72 additions & 18 deletions
@@ -739,24 +739,55 @@ def aten_chunk(self: TensorType, chunks: int, dim: int = 0) -> TensorType:
     raise NotImplementedError()


-def aten_clamp(
-    self: TensorType, min: Optional[float] = None, max: Optional[float] = None
-) -> TensorType:
+def aten_clamp(self: TensorType, min_=None, max_=None) -> TensorType:
     # clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor

-    raise NotImplementedError()
+    # TODO(justinchuby): Handle integer inputs
+    # FIXME(justinchuby): Enable test for this after None values are supported
+    # TODO(justinchuby): If min is greater than max, torch.clamp(..., min, max)
+    # sets all elements in input to the value of max.
+    if op.OptionalHasElement(min_):
+        min_ = op.OptionalGetElement(min_)
+        min_clamp = op.CastLike(min_, self)  # type: ignore[arg-type]
+    else:
+        min_clamp = op.Constant(value_float=float("-inf"))
+
+    if op.OptionalHasElement(max_):
+        max_ = op.OptionalGetElement(max_)
+        max_clamp = op.CastLike(max_, self)  # type: ignore[arg-type]
+    else:
+        max_clamp = op.Constant(value_float=float("inf"))
+
+    # Enforce the lower and upper bounds
+    clamped = op.Max(op.Min(self, max_clamp), min_clamp)  # type: ignore[arg-type]
+    return clamped


-def aten_clamp_max(self: TensorType, max: float) -> TensorType:
+def aten_clamp_max_scalar(self, max_):
     # clamp_max(Tensor self, Scalar max) -> Tensor

-    raise NotImplementedError()
+    max_ = op.CastLike(max_, self)
+    return op.Clip(self, None, max_)
+
+
+def aten_clamp_max_tensor(self, max_):
+    # clamp_max(Tensor self, Tensor max) -> Tensor
+
+    return op.Min(self, max_)


-def aten_clamp_min(self: TensorType, min: float) -> TensorType:
+def aten_clamp_min_scalar(self, min_):
     # clamp_min(Tensor self, Scalar min) -> Tensor
+    # NOTE: min_ is a rank 0 tensor.
+    # TODO(justinchuby): Specify the type constraints.
+    min_ = op.CastLike(min_, self)
+    return op.Clip(self, min_, None)

-    raise NotImplementedError()
+
+def aten_clamp_min_tensor(self, min_):
+    # clamp_min(Tensor self, Tensor min) -> Tensor
+    # TODO(justinchuby): Specify the type constraints.
+    return op.Max(self, min_)


 def aten_clip(
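
For intuition, here is a minimal NumPy sketch of the Min/Max lowering that `aten_clamp` performs (`clamp_reference` is a hypothetical helper, not part of this commit). It also demonstrates the min > max corner case flagged in the TODO, where this lowering returns the min bound while `torch.clamp` returns the max:

```python
import numpy as np

def clamp_reference(x, lo=None, hi=None):
    # A missing bound becomes +/-inf, mirroring op.Constant(value_float=...).
    lo = -np.inf if lo is None else lo
    hi = np.inf if hi is None else hi
    # Max(Min(x, hi), lo) enforces both bounds, like the ONNX lowering.
    return np.maximum(np.minimum(x, hi), lo)

x = np.array([-2.0, 0.5, 3.0])
print(clamp_reference(x, lo=-1.0, hi=1.0))  # [-1.   0.5  1. ]

# Corner case from the TODO (lo > hi): this lowering yields lo everywhere,
# whereas torch.clamp(x, 2.0, 1.0) sets every element to the max bound, 1.0.
print(clamp_reference(x, lo=2.0, hi=1.0))  # [2. 2. 2.]
```
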
@@ -1958,10 +1989,12 @@ def aten_gru_cell(
     raise NotImplementedError()


-def aten_gt(self: TensorType, other: TensorType) -> TensorType:
+def aten_gt(self, other):
     # gt.Tensor(Tensor self, Tensor other) -> Tensor

-    raise NotImplementedError()
+    # TODO(justinchuby): Input spec: non bool tensor
+    # Boolean inputs can be pre-casted by policy
+    return op.Greater(self, other)


 def aten_hamming_window(window_length: int) -> TensorType:

@@ -2572,10 +2605,12 @@ def aten_lstm_mps_backward(
     raise NotImplementedError()


-def aten_lt(self: TensorType, other: TensorType) -> TensorType:
+def aten_lt(self, other):
     # lt.Tensor(Tensor self, Tensor other) -> Tensor

-    raise NotImplementedError()
+    # TODO(justinchuby): Input spec: non bool tensor
+    # Boolean inputs can be pre-casted by policy
+    return op.Less(self, other)


 def aten_lu_solve(self: TensorType, LU_data: TensorType, LU_pivots: TensorType) -> TensorType:
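
Both `aten_gt` and `aten_lt` map one-to-one onto ONNX `Greater`/`Less`, which are not defined for bool tensors; the TODO defers that to a caller-side casting policy. A hedged sketch of what such a pre-cast could look like (`precast_bool` is hypothetical, for illustration only):

```python
import numpy as np

def precast_bool(x: np.ndarray) -> np.ndarray:
    # Greater/Less reject bool inputs; promote them to int32 first.
    return x.astype(np.int32) if x.dtype == np.bool_ else x

a = np.array([True, False, True])
b = np.array([False, False, True])
print(np.greater(precast_bool(a), precast_bool(b)))  # [ True False False]
```
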
@@ -3440,10 +3475,20 @@ def aten_ones(size: INT64) -> TensorType:
     raise NotImplementedError()


-def aten_ones_like(self: TensorType, memory_format: Optional[str] = None) -> TensorType:
+def aten_ones_like(self, dtype: int = -1):
+    """ones_like.
+
+    Note: dtype is an onnx enum. Users should convert torch dtype to onnx dtype
+    before calling this function.
+    """
     # ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor

-    raise NotImplementedError()
+    shape = op.Shape(self)
+    if dtype == -1:
+        one = op.CastLike(1, self)  # type: ignore[arg-type]
+    else:
+        one = op.Cast(1, to=dtype)  # type: ignore[arg-type]
+    return op.Expand(one, shape)


 def aten_or(self: TensorType, other: TensorType) -> TensorType:
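
Since `dtype` here is an ONNX `TensorProto` enum rather than a torch dtype, callers translate first (the test file below adds a `TORCH_TYPE_TO_ONNX` table for exactly this). A sketch of the intended call pattern and the Shape → Cast → Expand lowering, with illustrative values:

```python
import numpy as np
import onnx

x = np.array([[1.5, 2.5]], dtype=np.float32)

# Intended usage (ONNX enum, not torch.int64):
#   aten_ones_like(x)                                # ones with x's dtype via CastLike
#   aten_ones_like(x, dtype=onnx.TensorProto.INT64)  # ones cast to int64

# The lowering is equivalent to broadcasting a cast scalar 1 to Shape(x):
print(np.broadcast_to(np.float32(1), x.shape))  # [[1. 1.]]
```
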
@@ -3916,10 +3961,19 @@ def aten_renorm(self: TensorType, p: float, dim: int, maxnorm: float) -> TensorType:
     raise NotImplementedError()


-def aten_repeat(self: TensorType, repeats: INT64) -> TensorType:
+def aten_repeat(self, repeats: INT64):
     # repeat(Tensor self, SymInt[] repeats) -> Tensor

-    raise NotImplementedError()
+    # FIXME(justinchuby): When repeats.shape == [0]
+
+    # TODO(justinchuby): Make ones_like a function when onnxscript supports it
+    # shape = ones_like(repeats) := {
+    one = op.Constant(value_int=1)
+    repeats_shape = op.Shape(repeats)
+    shape = op.Expand(one, repeats_shape)
+    # }
+    self_expanded = op.Expand(self, shape)  # type: ignore[arg-type]
+    return op.Tile(self_expanded, repeats)


 def aten_repeat_interleave(
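
`torch.Tensor.repeat` requires `len(repeats) >= self.ndim` and behaves like `np.tile` once the input is promoted to rank `len(repeats)`, which is what the inlined `ones_like(repeats)` + `Expand` achieves. A NumPy sketch of the equivalence (`repeat_reference` is illustrative only):

```python
import numpy as np

def repeat_reference(x: np.ndarray, repeats: tuple) -> np.ndarray:
    # Promote x to rank len(repeats) with leading singleton dims, mirroring
    # Expand(self, ones_like(repeats)) in the ONNX lowering, then Tile.
    expanded = x.reshape((1,) * (len(repeats) - x.ndim) + x.shape)
    return np.tile(expanded, repeats)

x = np.array([1, 2])
print(repeat_reference(x, (3, 2)))
# [[1 2 1 2]
#  [1 2 1 2]
#  [1 2 1 2]]  -- same as torch.tensor([1, 2]).repeat(3, 2)
```
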
@@ -4012,10 +4066,10 @@ def aten_rot90(self: TensorType, k: int = 1, dims: Sequence[int] = (0, 1)) -> TensorType:
     raise NotImplementedError()


-def aten_round(self: TensorType) -> TensorType:
+def aten_round(self):
     # round(Tensor self) -> Tensor

-    raise NotImplementedError()
+    return op.Round(self)


 def aten_row_indices(self: TensorType) -> TensorType:
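
A behavioral note on the direct mapping: ONNX `Round` rounds half-way values to the nearest even integer, which matches `torch.round` (and `np.round`), so ties need no extra handling:

```python
import numpy as np

# Ties round to even ("banker's rounding") in ONNX Round, torch.round,
# and np.round alike, so the one-op mapping is behavior-preserving.
print(np.round(np.array([0.5, 1.5, 2.5, -0.5])))  # [ 0.  2.  2. -0.]
```
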

onnxscript/test/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 132 additions & 10 deletions
@@ -7,7 +7,9 @@
 from typing import Any, Callable, Collection, Iterable, Optional, Sequence, TypeVar

 import numpy as np
+import onnx
 import onnxruntime.capi.onnxruntime_pybind11_state
+import parameterized
 import torch
 from torch.testing._internal import common_device_type, common_methods_invocations
 from torch.testing._internal.opinfo import core as opinfo_core

@@ -69,14 +71,15 @@ class DecorateMeta:
     decorator: Callable[..., Any]
     dtypes: Optional[Collection[torch.dtype]]
     reason: str
+    matcher: Optional[Callable[[Any], bool]] = None


 def xfail(
     op_name: str,
     variant_name: str = "",
     *,
+    reason: str,
     dtypes: Optional[Collection[torch.dtype]] = None,
-    reason: Optional[str] = None,
 ):
     """Expects an OpInfo test to fail.

@@ -86,8 +89,6 @@ def xfail(
         dtypes: The dtypes to expect the failure.
         reason: The reason for the failure.
     """
-    if reason is None:
-        raise ValueError("Please specify a reason.")
     return DecorateMeta(
         op_name=op_name,
         variant_name=variant_name,

@@ -101,8 +102,9 @@ def skip(
     op_name: str,
     variant_name: str = "",
     *,
+    reason: str,
     dtypes: Optional[Collection[torch.dtype]] = None,
-    reason: Optional[str] = None,
+    matcher: Optional[Callable[[Any], Any]] = None,
 ):
     """Skips an OpInfo test.

@@ -111,15 +113,16 @@ def skip(
         variant_name: Optional OpInfo variant_test_name.
         dtypes: The dtypes to skip.
         reason: The reason for skipping.
+        matcher: A function that matches the test sample input. It is used only when
+            the skip is in the SKIP_SUBTESTS list.
     """
-    if reason is None:
-        raise ValueError("Please specify a reason.")
     return DecorateMeta(
         op_name=op_name,
         variant_name=variant_name,
         decorator=unittest.skip(f"Don't care: {reason}"),
         dtypes=dtypes,
         reason=reason,
+        matcher=matcher,
     )
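
Making `reason` a required keyword-only parameter turns the old runtime `ValueError` into an immediate `TypeError` at the call site, and `matcher` lets a skip target individual samples. A minimal standalone sketch of the signature change (the body is elided; arguments shown are examples):

```python
from typing import Any, Callable, Optional

def skip(op_name: str, variant_name: str = "", *, reason: str,
         dtypes=None, matcher: Optional[Callable[[Any], Any]] = None):
    ...  # builds a DecorateMeta, as in the diff

skip("repeat", reason="empty repeats unsupported",
     matcher=lambda sample: sample.args[0] == ())  # ok
# skip("repeat")  # TypeError: skip() missing 1 required keyword-only argument: 'reason'
```
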

@@ -159,17 +162,29 @@ def wrapped(fn):
 # Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py
 OPINFO_FUNCTION_MAPPING: dict[str, Callable[..., Any]] = {
     "add": core_ops.aten_add,
+    # "clamp": core_ops.aten_clamp,  # TODO(justinchuby): Enable
+    "clamp_max": core_ops.aten_clamp_max_tensor,
+    "clamp_min": core_ops.aten_clamp_min_tensor,
+    "gt": core_ops.aten_gt,
+    "lt": core_ops.aten_lt,
     "mul": core_ops.aten_mul,
     "nn.functional.elu": nn_ops.aten_elu,
     "nn.functional.relu6": nn_ops.aten_relu6,
     "nn.functional.selu": core_ops.aten_selu,
+    "ones_like": core_ops.aten_ones_like,
+    "repeat": core_ops.aten_repeat,
+    "round": core_ops.aten_round,
     "sub": core_ops.aten_sub,
 }

 TESTED_OPS = frozenset(OPINFO_FUNCTION_MAPPING)

 EXPECTED_SKIPS_OR_FAILS = (
     xfail("add", dtypes=BOOL_TYPES, reason="Add is not defined on bool tensors"),
+    xfail("clamp_max", dtypes=BOOL_TYPES, reason="Min is not defined on bool tensors"),
+    xfail("clamp_min", dtypes=BOOL_TYPES, reason="Max is not defined on bool tensors"),
+    xfail("gt", dtypes=BOOL_TYPES, reason="Greater is not defined on bool tensors"),
+    xfail("lt", dtypes=BOOL_TYPES, reason="Less is not defined on bool tensors"),
     xfail("mul", dtypes=BOOL_TYPES, reason="Mul is not defined on bool tensors"),
     xfail(
         "nn.functional.elu",

@@ -186,14 +201,117 @@ def wrapped(fn):
         dtypes=dtypes_except(torch.float16, torch.float32),
         reason="ONNX Runtime doesn't support float64 for Selu",
     ),
+    xfail(
+        "round",
+        variant_name="",
+        dtypes=dtypes_except(*FLOAT_TYPES),
+        reason="Round is not defined on non-float tensors",
+    ),
+    xfail("round", variant_name="decimals_0", reason="The ATen op does not support decimals"),
+    xfail("round", variant_name="decimals_3", reason="The ATen op does not support decimals"),
+    xfail(
+        "round", variant_name="decimals_neg_3", reason="The ATen op does not support decimals"
+    ),
     xfail("sub", dtypes=BOOL_TYPES, reason="Sub is not defined on bool tensors"),
 )
+
+
+SKIP_SUBTESTS = (
+    skip(
+        "clamp_max",
+        reason="Empty tensor not yet supported",
+        matcher=lambda sample: sample.input.size() == torch.Size([0]),
+    ),
+    skip(
+        "clamp_min",
+        reason="Empty tensor not yet supported",
+        matcher=lambda sample: sample.input.size() == torch.Size([0]),
+    ),
+    skip(
+        "repeat",
+        reason="repeating when input is a scalar and repeats is empty is not supported",
+        matcher=lambda sample: sample.args[0] == (),
+    ),
+)
+OP_WITH_SKIPPED_SUBTESTS = frozenset(meta.op_name for meta in SKIP_SUBTESTS)
+
 # END OF SECTION TO MODIFY #####################################################


 OPS_DB = copy.deepcopy(common_methods_invocations.op_db)


+TORCH_TYPE_TO_ONNX = {
+    torch.bool: onnx.TensorProto.BOOL,
+    torch.uint8: onnx.TensorProto.UINT8,
+    torch.int8: onnx.TensorProto.INT8,
+    torch.int16: onnx.TensorProto.INT16,
+    torch.int32: onnx.TensorProto.INT32,
+    torch.int64: onnx.TensorProto.INT64,
+    torch.float16: onnx.TensorProto.FLOAT16,
+    torch.float32: onnx.TensorProto.FLOAT,
+    torch.float64: onnx.TensorProto.DOUBLE,
+    torch.complex64: onnx.TensorProto.COMPLEX64,
+    torch.complex128: onnx.TensorProto.COMPLEX128,
+    torch.bfloat16: onnx.TensorProto.BFLOAT16,
+}
+
+
+class TestFunctionsCompilation(unittest.TestCase):
+    """Test all functions can be compiled."""
+
+    @parameterized.parameterized.expand(
+        list(OPINFO_FUNCTION_MAPPING.items()),
+    )
+    def test_function_compiles(self, _, function):
+        compiled = onnxscript.script()(function)
+        compiled.to_function_proto()
+
+
+def _convert_tensor_to_numpy(input: Any) -> Any:
+    if isinstance(input, torch.Tensor):
+        return input.detach().cpu().numpy()
+    if isinstance(input, (tuple, list)):
+        if len(input) == 0:
+            return np.array((), dtype=np.int64)
+        if isinstance(input[0], torch.Tensor):
+            return [_convert_tensor_to_numpy(x) for x in input]
+        if isinstance(input[0], (int, float)):
+            # Just a tuple of numbers
+            return np.array(input)
+        return input
+
+    return input
+
+
+def _convert_kwargs_for_onnx(kwargs: dict[str, Any]) -> dict[str, Any]:
+    """Converts kwargs to be compatible with ONNX Runtime.
+
+    Drops the `device` kwarg and maps torch dtypes to ONNX TensorProto enums.
+    """
+    new_kwargs = {}
+    for key, value in kwargs.items():
+        if key == "device":
+            continue
+        if key == "dtype":
+            value = TORCH_TYPE_TO_ONNX[value]
+        new_kwargs[key] = value
+    return new_kwargs
+
+
+def _should_skip_test_sample(op_name: str, sample) -> Optional[str]:
+    """Returns a reason if a test sample should be skipped."""
+    if op_name not in OP_WITH_SKIPPED_SUBTESTS:
+        return None
+    for decorator_meta in SKIP_SUBTESTS:
+        # Linear search on SKIP_SUBTESTS. That's fine because the list is small.
+        if decorator_meta.op_name == op_name:
+            assert decorator_meta.matcher is not None, "Matcher must be defined"
+            if decorator_meta.matcher(sample):
+                return decorator_meta.reason
+    return None
+
+
 class TestOutputConsistency(unittest.TestCase):
     """Test output consistency between exported ONNX models and PyTorch eager mode.
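
Taken together, these helpers turn an OpInfo sample into ONNX-friendly inputs before the scripted function runs: tensors become ndarrays, int/float tuples become arrays, `device` is dropped, and `dtype` becomes a `TensorProto` enum. A sketch with illustrative values (expected results shown as comments):

```python
import onnx
import torch

sample_args = (torch.tensor([1.0, 2.0]), (2, 3))
sample_kwargs = {"dtype": torch.float32, "device": "cpu"}

# After _convert_tensor_to_numpy on each arg:
#   [array([1., 2.], dtype=float32), array([2, 3])]
# After _convert_kwargs_for_onnx(sample_kwargs):
#   {"dtype": onnx.TensorProto.FLOAT}  # device dropped, dtype mapped to the enum
```
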
@@ -236,10 +354,14 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
             inputs=repr(inputs),
             kwargs=repr(cpu_sample.kwargs),
         ):
-            input_numpy = [x.numpy() for x in inputs if isinstance(x, torch.Tensor)]
-            torch_output = op(*inputs, **cpu_sample.kwargs)
+            skip_reason = _should_skip_test_sample(op.name, cpu_sample)
+            if skip_reason is not None:
+                self.skipTest(skip_reason)
+            input_onnx = [_convert_tensor_to_numpy(x) for x in inputs]
+            kwargs_onnx = _convert_kwargs_for_onnx(cpu_sample.kwargs)
+            output_torch = op(*inputs, **cpu_sample.kwargs)
             try:
-                function_output = scripted_function(*input_numpy, **cpu_sample.kwargs)
+                function_output = scripted_function(*input_onnx, **kwargs_onnx)
                 # pylint: disable=c-extension-no-member
             except onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented:
                 self.skipTest(

@@ -250,7 +372,7 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
             # Use torch testing to ensure dtypes and shapes match
             torch.testing.assert_close(
                 torch.tensor(function_output),
-                torch_output,
+                output_torch,
             )
