Skip to content

Commit f80920b

Browse files
authored
Merge branch 'main' into new_ops_2
2 parents c30e0c8 + edfa7c1 commit f80920b

File tree

4 files changed

+61
-13
lines changed

4 files changed

+61
-13
lines changed

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -186,15 +186,15 @@ def aten_alpha_dropout(input: TensorType, p: float, train: bool) -> TensorType:
186186
raise NotImplementedError()
187187

188188

189-
# @torch_op("aten::amax") # FIXME(#249): Uncomment when CI uses onnx 1.13
189+
@torch_op("aten::amax")
190190
def aten_amax(self: TReal, dim: INT64, keepdim: int = 0) -> TReal:
191191
# amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
192192

193193
# TODO(justinchuby): Make dim optional, keepdim bool
194194
return op.ReduceMax(self, dim, keepdims=keepdim)
195195

196196

197-
# @torch_op("aten::amin") # FIXME(#249): Uncomment when CI uses onnx 1.13
197+
@torch_op("aten::amin")
198198
def aten_amin(self: TReal, dim: INT64, keepdim: int = 0) -> TReal:
199199
# amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
200200

@@ -2703,7 +2703,7 @@ def aten_logspace(start: float, end: float, steps: int, base: float = 10.0) -> T
27032703
raise NotImplementedError()
27042704

27052705

2706-
@torch_op("aten::logsumexp", trace_only=True) # FIXME(#249): Script when CI uses onnx 1.13
2706+
@torch_op("aten::logsumexp")
27072707
def aten_logsumexp(self: TReal, dim: INT64, keepdim: int = False) -> TReal:
27082708
# logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor
27092709

@@ -4408,16 +4408,30 @@ def aten_sinh(self: TFloat) -> TFloat:
44084408
return op.Sinh(self)
44094409

44104410

4411+
@torch_op("aten::slice")
44114412
def aten_slice(
4412-
self: TensorType,
4413+
self: TTensor,
44134414
dim: int = 0,
44144415
start: Optional[INT64] = None,
44154416
end: Optional[INT64] = None,
44164417
step: INT64 = 1,
4417-
) -> TensorType:
4418+
) -> TTensor:
44184419
# slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)
44194420

4420-
raise NotImplementedError()
4421+
# TODO: use OptionalHasElement() to check the start/end values
4422+
start = op.Cast(start, to=INT64.dtype)
4423+
start = op.Reshape(start, op.Constant(value_ints=[-1]))
4424+
4425+
end = op.Cast(end, to=INT64.dtype)
4426+
end = op.Reshape(end, op.Constant(value_ints=[-1]))
4427+
4428+
dim = op.Cast(dim, to=INT64.dtype)
4429+
dim = op.Reshape(dim, op.Constant(value_ints=[-1]))
4430+
4431+
step = op.Cast(step, to=INT64.dtype)
4432+
step = op.Reshape(step, op.Constant(value_ints=[-1]))
4433+
4434+
return op.Slice(self, start, end, dim, step)
44214435

44224436

44234437
def aten_slice_backward(

onnxscript/test/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,9 @@ def _topk_input_wrangler(
221221

222222
# Ops to be tested for numerical consistency between onnx and pytorch
223223
# Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py
224-
OPINFO_FUNCTION_MAPPING: dict[
224+
225+
# Split the scripted and traced ops to make sure we don't forget to script an op
226+
OPINFO_FUNCTION_MAPPING_SCRIPTED: dict[
225227
str,
226228
onnxscript.OnnxFunction
227229
| Callable[..., Any]
@@ -245,7 +247,6 @@ def _topk_input_wrangler(
245247
"atan": core_ops.aten_atan,
246248
"atanh": core_ops.aten_atanh,
247249
"bmm": core_ops.aten_bmm,
248-
"cat": core_ops.aten_cat,
249250
"ceil": core_ops.aten_ceil,
250251
"clamp_max": core_ops.aten_clamp_max,
251252
"clamp_min": core_ops.aten_clamp_min,
@@ -267,18 +268,17 @@ def _topk_input_wrangler(
267268
"full": (core_ops.aten_full, _full_input_wrangler),
268269
"full_like": core_ops.aten_full_like,
269270
"gt": core_ops.aten_gt,
270-
"index_select": core_ops.aten_index_select,
271271
"isinf": core_ops.aten_isinf,
272272
"log": core_ops.aten_log,
273273
"log10": core_ops.aten_log10,
274274
"log1p": core_ops.aten_log1p,
275+
"log_softmax": (special_ops.aten_special_log_softmax, _log_softmax_input_wrangler),
275276
"log2": core_ops.aten_log2,
276277
"logaddexp": core_ops.aten_logaddexp,
277278
"logaddexp2": core_ops.aten_logaddexp2,
278279
"logcumsumexp": core_ops.aten_logcumsumexp,
279280
"logdet": core_ops.aten_logdet,
280281
"logsumexp": (core_ops.aten_logsumexp, _logcumsumexp_input_wrangler),
281-
"log_softmax": (special_ops.aten_special_log_softmax, _log_softmax_input_wrangler),
282282
"lt": core_ops.aten_lt,
283283
"matmul": core_ops.aten_matmul,
284284
"mm": core_ops.aten_mm,
@@ -315,12 +315,12 @@ def _topk_input_wrangler(
315315
"sign": core_ops.aten_sign,
316316
"sin": core_ops.aten_sin,
317317
"sinh": core_ops.aten_sinh,
318+
"slice": core_ops.aten_slice,
318319
"sqrt": core_ops.aten_sqrt,
319320
"sub": core_ops.aten_sub,
320321
"t": core_ops.aten_t,
321322
"tan": core_ops.aten_tan,
322323
"tanh": core_ops.aten_tanh,
323-
"transpose": core_ops.aten_transpose,
324324
"topk": (
325325
core_ops.aten_topk,
326326
_topk_input_wrangler,
@@ -333,6 +333,26 @@ def _topk_input_wrangler(
333333
"zeros_like": core_ops.aten_zeros_like,
334334
}
335335

336+
337+
OPINFO_FUNCTION_MAPPING_TRACE_ONLY: dict[
338+
str,
339+
Callable[..., Any] | tuple[Callable[..., Any], Callable[..., Any]],
340+
] = {
341+
"cat": core_ops.aten_cat,
342+
"index_select": core_ops.aten_index_select,
343+
"transpose": core_ops.aten_transpose,
344+
}
345+
346+
OPINFO_FUNCTION_MAPPING: dict[
347+
str,
348+
onnxscript.OnnxFunction
349+
| Callable[..., Any]
350+
| tuple[
351+
onnxscript.OnnxFunction | Callable[..., Any],
352+
Callable[[list[Any], dict[str, Any]], tuple[list[Any], dict[str, Any]]],
353+
],
354+
] = {**OPINFO_FUNCTION_MAPPING_SCRIPTED, **OPINFO_FUNCTION_MAPPING_TRACE_ONLY}
355+
336356
TESTED_OPS = frozenset(OPINFO_FUNCTION_MAPPING)
337357

338358
EXPECTED_SKIPS_OR_FAILS = (
@@ -420,6 +440,12 @@ def _topk_input_wrangler(
420440
matcher=lambda sample: "scale_factor" in sample.kwargs,
421441
reason="fixme: the scale_factor tests",
422442
),
443+
skip(
444+
"slice",
445+
# kwargs {dim, start, end, step} are empty, so we cannot supply the default values
446+
matcher=lambda sample: len(sample.kwargs) == 0,
447+
reason="start and end must be 1-D arrays and cannot be optional, because ORT 1.13 does not support that yet",
448+
),
423449
)
424450

425451
duplicate_opinfo(
@@ -523,6 +549,14 @@ def setUp(self) -> None:
523549
torch.manual_seed(42)
524550
np.random.seed(42)
525551

552+
def test_all_script_functions_are_onnx_functions(self):
553+
for func_with_wrangler in OPINFO_FUNCTION_MAPPING_SCRIPTED.values():
554+
if isinstance(func_with_wrangler, tuple):
555+
func = func_with_wrangler[0]
556+
else:
557+
func = func_with_wrangler
558+
self.assertIsInstance(func, onnxscript.OnnxFunction)
559+
526560
@common_device_type.ops( # type: ignore[misc]
527561
[info for info in OPS_DB if info.name in TESTED_OPS],
528562
allowed_dtypes=TESTED_DTYPES,

onnxscript/test/onnx_backend_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
# Licensed under the MIT License.
44
# --------------------------------------------------------------------------
55

6+
# pylint: disable=too-many-boolean-expressions
7+
68
import importlib
79
import os
810
import unittest

pyproject_pylint.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,7 @@ disable = [
1919
"redefined-builtin", # TODO: should we avoid redefined-builtin?
2020
"too-few-public-methods",
2121
"too-many-arguments",
22-
"too-many-boolean-expressions",
2322
"too-many-branches",
24-
"too-many-function-args",
2523
"too-many-instance-attributes",
2624
"too-many-lines",
2725
"too-many-locals",

0 commit comments

Comments
 (0)