Skip to content

Commit cf27ba8

Browse files
authored
test(atenlib): split mappings to two (#308)
Split the scripted and traced ops into two mappings, and added a test to make sure every scripted function is compiled as an OnnxFunction.
1 parent 91701fa commit cf27ba8

File tree

4 files changed

+37
-10
lines changed

4 files changed

+37
-10
lines changed

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,15 +186,15 @@ def aten_alpha_dropout(input: TensorType, p: float, train: bool) -> TensorType:
186186
raise NotImplementedError()
187187

188188

189-
# @torch_op("aten::amax") # FIXME(#249): Uncomment when CI uses onnx 1.13
189+
@torch_op("aten::amax")
190190
def aten_amax(self: TReal, dim: INT64, keepdim: int = 0) -> TReal:
191191
# amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
192192

193193
# TODO(justinchuby): Make dim optional, keepdim bool
194194
return op.ReduceMax(self, dim, keepdims=keepdim)
195195

196196

197-
# @torch_op("aten::amin") # FIXME(#249): Uncomment when CI uses onnx 1.13
197+
@torch_op("aten::amin")
198198
def aten_amin(self: TReal, dim: INT64, keepdim: int = 0) -> TReal:
199199
# amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
200200

@@ -2704,7 +2704,7 @@ def aten_logspace(start: float, end: float, steps: int, base: float = 10.0) -> T
27042704
raise NotImplementedError()
27052705

27062706

2707-
@torch_op("aten::logsumexp", trace_only=True) # FIXME(#249): Script when CI uses onnx 1.13
2707+
@torch_op("aten::logsumexp")
27082708
def aten_logsumexp(self: TReal, dim: INT64, keepdim: int = False) -> TReal:
27092709
# logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor
27102710

onnxscript/test/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,9 @@ def _topk_input_wrangler(
221221

222222
# Ops to be tested for numerical consistency between onnx and pytorch
223223
# Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py
224-
OPINFO_FUNCTION_MAPPING: dict[
224+
225+
# Split the scripted and traced ops to make sure we don't forget to script an op
226+
OPINFO_FUNCTION_MAPPING_SCRIPTED: dict[
225227
str,
226228
onnxscript.OnnxFunction
227229
| Callable[..., Any]
@@ -245,7 +247,6 @@ def _topk_input_wrangler(
245247
"atan": core_ops.aten_atan,
246248
"atanh": core_ops.aten_atanh,
247249
"bmm": core_ops.aten_bmm,
248-
"cat": core_ops.aten_cat,
249250
"ceil": core_ops.aten_ceil,
250251
"clamp_max": core_ops.aten_clamp_max,
251252
"clamp_min": core_ops.aten_clamp_min,
@@ -267,18 +268,17 @@ def _topk_input_wrangler(
267268
"full": (core_ops.aten_full, _full_input_wrangler),
268269
"full_like": core_ops.aten_full_like,
269270
"gt": core_ops.aten_gt,
270-
"index_select": core_ops.aten_index_select,
271271
"isinf": core_ops.aten_isinf,
272272
"log": core_ops.aten_log,
273273
"log10": core_ops.aten_log10,
274274
"log1p": core_ops.aten_log1p,
275+
"log_softmax": (special_ops.aten_special_log_softmax, _log_softmax_input_wrangler),
275276
"log2": core_ops.aten_log2,
276277
"logaddexp": core_ops.aten_logaddexp,
277278
"logaddexp2": core_ops.aten_logaddexp2,
278279
"logcumsumexp": core_ops.aten_logcumsumexp,
279280
"logdet": core_ops.aten_logdet,
280281
"logsumexp": (core_ops.aten_logsumexp, _logcumsumexp_input_wrangler),
281-
"log_softmax": (special_ops.aten_special_log_softmax, _log_softmax_input_wrangler),
282282
"lt": core_ops.aten_lt,
283283
"matmul": core_ops.aten_matmul,
284284
"mm": core_ops.aten_mm,
@@ -319,7 +319,6 @@ def _topk_input_wrangler(
319319
"t": core_ops.aten_t,
320320
"tan": core_ops.aten_tan,
321321
"tanh": core_ops.aten_tanh,
322-
"transpose": core_ops.aten_transpose,
323322
"topk": (
324323
core_ops.aten_topk,
325324
_topk_input_wrangler,
@@ -332,6 +331,26 @@ def _topk_input_wrangler(
332331
"zeros_like": core_ops.aten_zeros_like,
333332
}
334333

334+
335+
OPINFO_FUNCTION_MAPPING_TRACE_ONLY: dict[
336+
str,
337+
Callable[..., Any] | tuple[Callable[..., Any], Callable[..., Any]],
338+
] = {
339+
"cat": core_ops.aten_cat,
340+
"index_select": core_ops.aten_index_select,
341+
"transpose": core_ops.aten_transpose,
342+
}
343+
344+
OPINFO_FUNCTION_MAPPING: dict[
345+
str,
346+
onnxscript.OnnxFunction
347+
| Callable[..., Any]
348+
| tuple[
349+
onnxscript.OnnxFunction | Callable[..., Any],
350+
Callable[[list[Any], dict[str, Any]], tuple[list[Any], dict[str, Any]]],
351+
],
352+
] = {**OPINFO_FUNCTION_MAPPING_SCRIPTED, **OPINFO_FUNCTION_MAPPING_TRACE_ONLY}
353+
335354
TESTED_OPS = frozenset(OPINFO_FUNCTION_MAPPING)
336355

337356
EXPECTED_SKIPS_OR_FAILS = (
@@ -522,6 +541,14 @@ def setUp(self) -> None:
522541
torch.manual_seed(42)
523542
np.random.seed(42)
524543

544+
def test_all_script_functions_are_onnx_functions(self):
545+
for func_with_wrangler in OPINFO_FUNCTION_MAPPING_SCRIPTED.values():
546+
if isinstance(func_with_wrangler, tuple):
547+
func = func_with_wrangler[0]
548+
else:
549+
func = func_with_wrangler
550+
self.assertIsInstance(func, onnxscript.OnnxFunction)
551+
525552
@common_device_type.ops( # type: ignore[misc]
526553
[info for info in OPS_DB if info.name in TESTED_OPS],
527554
allowed_dtypes=TESTED_DTYPES,

onnxscript/test/onnx_backend_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
# Licensed under the MIT License.
44
# --------------------------------------------------------------------------
55

6+
# pylint: disable=too-many-boolean-expressions
7+
68
import importlib
79
import os
810
import unittest

pyproject_pylint.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,7 @@ disable = [
1919
"redefined-builtin", # TODO: should we avoid redefined-builtin?
2020
"too-few-public-methods",
2121
"too-many-arguments",
22-
"too-many-boolean-expressions",
2322
"too-many-branches",
24-
"too-many-function-args",
2523
"too-many-instance-attributes",
2624
"too-many-lines",
2725
"too-many-locals",

0 commit comments

Comments
 (0)