
Commit a89a2a9

Create env to test with TORCHLIB_EXPERIMENTAL_PREFER_TRACING on | test(torchlib) (#1180)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):

* #1178
* #1177
* #1176
* __->__ #1180

### Changes

- Create env to test with TORCHLIB_EXPERIMENTAL_PREFER_TRACING on
- Test with Python 3.11 as well
- Fixes #1061
- Fix aten::any.dims and aten::all.dims
1 parent: 77ef131

6 files changed: +145 -33 lines

.github/workflows/main.yaml (8 additions, 0 deletions)

```diff
@@ -31,6 +31,8 @@ jobs:
           - py310-torch-nightly
           - py310-onnx-weekly
           - py310-ort-nightly
+          - py311-ort-nightly
+          - py310-experimental-torchlib-tracing
         include:
           - name: py310
             python-version: "3.10"
@@ -50,6 +52,12 @@ jobs:
           - name: py310-ort-nightly
             python-version: "3.10"
             nox-tag: test-ort-nightly
+          - name: py311-ort-nightly
+            python-version: "3.11"
+            nox-tag: test-ort-nightly
+          - name: py310-experimental-torchlib-tracing
+            python-version: "3.10"
+            nox-tag: test-experimental-torchlib-tracing
     runs-on: ${{ matrix.os }}
     steps:
       - uses: actions/checkout@v4
```

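For context, each `include` entry binds a matrix `name` to the Python version and nox tag that job runs. A rough Python sketch of how the two new entries expand (illustrative only, not part of the change):

```python
# Illustrative expansion of the two new matrix entries (not part of the diff).
matrix_include = [
    {"name": "py311-ort-nightly", "python": "3.11", "nox_tag": "test-ort-nightly"},
    {
        "name": "py310-experimental-torchlib-tracing",
        "python": "3.10",
        "nox_tag": "test-experimental-torchlib-tracing",
    },
]
for job in matrix_include:
    # Each job runs the nox sessions carrying its tag, i.e. `nox -t <tag>`.
    print(f"{job['name']}: Python {job['python']} -> nox -t {job['nox_tag']}")
```
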
noxfile.py (18 additions, 1 deletion)

```diff
@@ -11,7 +11,7 @@

 COMMON_TEST_DEPENDENCIES = (
     "jinja2",
-    "numpy==1.23.5",
+    "numpy==1.24.4",
     "typing_extensions",
     "beartype!=0.16.0",
     "types-PyYAML",
@@ -95,3 +95,20 @@ def test_ort_nightly(session):
     session.install(".", "--no-deps")
     session.run("pip", "list")
     session.run("pytest", "onnxscript", *session.posargs)
+
+
+@nox.session(tags=["test-experimental-torchlib-tracing"])
+def test_experimental_torchlib_tracing(session):
+    """Test TorchLib with the experimental TORCHLIB_EXPERIMENTAL_PREFER_TRACING flag on."""
+    session.install(
+        *COMMON_TEST_DEPENDENCIES, PYTORCH, ONNX, *ONNX_RUNTIME_NIGHTLY_DEPENDENCIES
+    )
+    session.install("-r", "requirements/ci/requirements-ort-nightly.txt")
+    session.install(".", "--no-deps")
+    session.run("pip", "list")
+    session.run(
+        "pytest",
+        "onnxscript/tests/function_libs/torch_lib/ops_test.py",
+        *session.posargs,
+        env={"TORCHLIB_EXPERIMENTAL_PREFER_TRACING": "1"},
+    )
```

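The new session exports `TORCHLIB_EXPERIMENTAL_PREFER_TRACING=1` into the pytest process. A minimal sketch of how such a flag is typically consumed, assuming a read-once module-level switch (the helper names are hypothetical, not the actual onnxscript internals):

```python
import os

# Hypothetical sketch: read the experimental flag once at import time.
_PREFER_TRACING = os.environ.get("TORCHLIB_EXPERIMENTAL_PREFER_TRACING", "0") == "1"

def choose_lowering(trace_only: bool) -> str:
    """Pick tracing when the op requires it or the experimental flag is set."""
    return "trace" if trace_only or _PREFER_TRACING else "script"

print(choose_lowering(trace_only=False))  # "trace" when the env var is "1"
```
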
onnxscript/function_libs/torch_lib/ops/core.py (10 additions, 7 deletions)

```diff
@@ -365,7 +365,9 @@ def aten_all_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False)
     if not dim:
         return aten_all_dims_no_dim(self, keepdim)
     for d in dim:
-        self = aten_all_dim(self, d, keepdim)
+        self = aten_all_dim(self, d, keepdim=True)
+    if not keepdim:
+        self = op.Squeeze(self, list(dim))
     return self


@@ -488,7 +490,9 @@ def aten_any_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False)
     if not dim:
         return aten_any_dims_no_dim(self, keepdim)
     for d in dim:
-        self = aten_any_dim(self, d, keepdim)
+        self = aten_any_dim(self, d, keepdim=True)
+    if not keepdim:
+        self = op.Squeeze(self, list(dim))
     return self


@@ -7339,17 +7343,16 @@ def aten_smm(self: TensorType, mat2: TensorType) -> TensorType:
     raise NotImplementedError()


-@torch_op(("aten::softmax", "aten::softmax.int", "aten::special_softmax"))
-def aten_softmax(
-    self: TFloatOrBFloat16, dim: int, dtype: int = FLOAT.dtype
-) -> TFloatOrBFloat16:
+@torch_op(("aten::softmax", "aten::softmax.int", "aten::special_softmax"), trace_only=True)
+def aten_softmax(self: TFloatOrBFloat16, dim: int, dtype: int = -1) -> TFloatOrBFloat16:
     """softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"""

     self_is_scalar = IsScalar(self)
     if self_is_scalar:
         self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
     result = op.Softmax(self, axis=dim)
-    result = op.Cast(result, to=dtype)
+    if dtype != -1:
+        result = op.Cast(result, to=dtype)
     if self_is_scalar:
         # Convert to scalar when input is scalar
         result = op.Squeeze(result)
```

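The `aten_all_dims`/`aten_any_dims` fix is about axis renumbering: reducing one axis at a time with `keepdim=False` shifts the indices of the axes still to be reduced. A small NumPy sketch of the corrected pattern (NumPy stands in for the ONNX ops; this is an analogy, not the library code):

```python
import numpy as np

x = (np.arange(24).reshape(2, 3, 4) % 5) > 0
dims = (0, 2)

# Keep reduced axes as size-1 during the loop so the indices in `dims`
# stay valid, then drop them all in a single squeeze, mirroring the
# keepdim=True + op.Squeeze(self, list(dim)) fix above.
out = x
for d in dims:
    out = np.all(out, axis=d, keepdims=True)
out = out.squeeze(axis=dims)

assert np.array_equal(out, np.all(x, axis=dims))  # matches a direct multi-axis reduce
```
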
onnxscript/function_libs/torch_lib/ops/special.py (4 additions, 4 deletions)

```diff
@@ -13,7 +13,6 @@

 from typing import Optional, Sequence

-from onnxscript import FLOAT
 from onnxscript.function_libs.torch_lib.ops import common as common_ops
 from onnxscript.function_libs.torch_lib.registration import torch_op
 from onnxscript.function_libs.torch_lib.tensor_typing import TFloatOrBFloat16
@@ -212,17 +211,18 @@ def aten_special_log_ndtr(self: TensorType) -> TensorType:
     raise NotImplementedError()


-@torch_op(("aten::log_softmax", "aten::special_log_softmax"))
+@torch_op(("aten::log_softmax", "aten::special_log_softmax"), trace_only=True)
 def aten_special_log_softmax(
-    self: TFloatOrBFloat16, dim: int, dtype: int = FLOAT.dtype
+    self: TFloatOrBFloat16, dim: int, dtype: int = -1
 ) -> TFloatOrBFloat16:
     """special_log_softmax(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor"""

     self_is_scalar = IsScalar(self)
     if self_is_scalar:
         self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
     result = op.LogSoftmax(self, axis=dim)
-    result = op.Cast(result, to=dtype)
+    if dtype != -1:
+        result = op.Cast(result, to=dtype)
     if self_is_scalar:  # squeeze to scalar due to input is scalar
         result = op.Squeeze(result)
     return result
```

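Both softmax variants now use `-1` as a "no dtype requested" sentinel in place of the old `FLOAT.dtype` default, and since they are registered with `trace_only=True`, the `if dtype != -1` branch runs in Python while the graph is built, so a Cast node is emitted only when a dtype is actually passed. A rough NumPy analogue (illustrative names; the enum values follow ONNX TensorProto, an assumption worth checking):

```python
import numpy as np

# Assumed ONNX TensorProto dtype enums: FLOAT=1, FLOAT16=10, DOUBLE=11.
_ONNX_TO_NUMPY = {1: np.float32, 10: np.float16, 11: np.float64}

def log_softmax_like(x, dim, dtype=-1):
    shifted = x - x.max(axis=dim, keepdims=True)  # stabilize before exp
    result = shifted - np.log(np.exp(shifted).sum(axis=dim, keepdims=True))
    if dtype != -1:  # with trace_only=True this branch is decided at build time
        result = result.astype(_ONNX_TO_NUMPY[dtype])
    return result

x = np.random.rand(2, 3).astype(np.float32)
assert log_softmax_like(x, dim=1).dtype == np.float32            # no Cast emitted
assert log_softmax_like(x, dim=1, dtype=10).dtype == np.float16  # Cast to float16
```
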
onnxscript/tests/function_libs/torch_lib/ops_test_common.py (4 additions, 0 deletions)

```diff
@@ -161,6 +161,10 @@ def add_decorate_info(
     ops_mapping = {(info.name, info.variant_test_name): info for info in all_opinfos}
     for decorate_meta in skip_or_xfails:
         opinfo = ops_mapping.get((decorate_meta.op_name, decorate_meta.variant_name))
+        if opinfo is None and not decorate_meta.enabled_if:
+            # If the OpInfo doesn't exist and it is not enabled, we skip the OpInfo
+            # because it could be an OpInfo that is in torch-nightly but not older versions.
+            continue
         assert (
             opinfo is not None
         ), f"Couldn't find OpInfo for {decorate_meta}. Did you need to specify variant_name?"
```

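The new guard in `add_decorate_info` tolerates skip/xfail entries whose OpInfo only exists in newer torch releases: when `enabled_if` is already false, a missing OpInfo is skipped instead of tripping the assertion. A self-contained illustration with stand-in data (not the real test-harness types):

```python
from dataclasses import dataclass

@dataclass
class DecorateMeta:  # trimmed stand-in for the harness class of the same name
    op_name: str
    enabled_if: bool

ops_mapping = {"arange": object()}  # pretend only 'arange' exists in this torch version

for meta in (DecorateMeta("arange", True), DecorateMeta("nightly_only_op", False)):
    opinfo = ops_mapping.get(meta.op_name)
    if opinfo is None and not meta.enabled_if:
        continue  # disabled entry for an OpInfo this torch version lacks
    assert opinfo is not None, f"Couldn't find OpInfo for {meta.op_name}"
```
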
onnxscript/tests/function_libs/torch_lib/ops_test_data.py (101 additions, 21 deletions)

```diff
@@ -471,21 +471,41 @@ def _where_input_wrangler(
     ),
     TorchLibOpInfo("ops.aten._log_softmax", core_ops.aten__log_softmax),
     TorchLibOpInfo(
-        "ops.aten._log_softmax_half", core_ops.aten__log_softmax_half, trace_only=True
-    ).xfail(
+        "ops.aten._log_softmax_half",
+        core_ops.aten__log_softmax_half,
+        trace_only=True,
+        tolerance={torch.float16: (1e-3, 1e-3)},
+    )
+    .xfail(
         reason="PyTorch does not implement _log_softmax for float16 on CPU",
         dtypes=(torch.float16,),
+        enabled_if=version_utils.torch_older_than("2.2"),
+    )
+    .xfail(
+        dtypes=(torch.float16,),
+        reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
+        test_class_name="TestOutputConsistencyFullGraph",
     ),
     TorchLibOpInfo("ops.aten._softmax", core_ops.aten__softmax, trace_only=True),
-    TorchLibOpInfo(
-        "ops.aten._softmax_half", core_ops.aten__softmax_half, trace_only=True
-    ).xfail(
+    TorchLibOpInfo("ops.aten._softmax_half", core_ops.aten__softmax_half, trace_only=True)
+    .xfail(
         reason="PyTorch does not implement _softmax for float16 on CPU",
         dtypes=(torch.float16,),
+        enabled_if=version_utils.torch_older_than("2.2"),
+    )
+    .xfail(
+        dtypes=(torch.float16,),
+        reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
+        test_class_name="TestOutputConsistencyFullGraph",
+    ),
+    TorchLibOpInfo("all_dim", core_ops.aten_all_dim).skip(
+        matcher=lambda sample: not (len(sample.kwargs) > 0)
+        or isinstance(sample.kwargs.get("dim"), tuple),
+        reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design. dim must be an integer",
     ),
-    TorchLibOpInfo("all_dim", core_ops.aten_all_dim).xfail(
-        matcher=lambda sample: not (len(sample.kwargs) > 0),
-        reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design",
+    TorchLibOpInfo("all_dims", core_ops.aten_all_dims, trace_only=True).skip(
+        matcher=lambda sample: not isinstance(sample.kwargs.get("dim"), tuple),
+        reason="this overload requires dim to be a tuple",
     ),
     TorchLibOpInfo("allclose", core_ops.aten_allclose),
     TorchLibOpInfo(
@@ -501,7 +521,11 @@ def _where_input_wrangler(
     TorchLibOpInfo("acosh", core_ops.aten_acosh),
     TorchLibOpInfo("add", core_ops.aten_add, tolerance={torch.float16: (1e-3, 1e-3)}),
     TorchLibOpInfo("add", core_ops.aten_add_complex, complex=True, trace_only=True),
-    TorchLibOpInfo("addbmm", core_ops.aten_addbmm, tolerance={torch.float32: (2e-5, 2e-5)}),
+    TorchLibOpInfo(
+        "addbmm",
+        core_ops.aten_addbmm,
+        tolerance={torch.float32: (2e-5, 2e-5), torch.float16: (2e-1, 2e-2)},
+    ),
     TorchLibOpInfo("addcdiv", core_ops.aten_addcdiv),
     TorchLibOpInfo("addcmul", core_ops.aten_addcmul, tolerance={torch.float16: (4e-3, 3e-3)}),
     TorchLibOpInfo("addmm", core_ops.aten_addmm)
@@ -522,7 +546,7 @@ def _where_input_wrangler(
         dtypes=(torch.int16, torch.int32, torch.int64),
         reason="ONNX Runtime does not support int inputs to Gemm",
     ),
-    TorchLibOpInfo("addmv", core_ops.aten_addmv),
+    TorchLibOpInfo("addmv", core_ops.aten_addmv, tolerance={torch.float16: (1e-3, 1e-2)}),
     TorchLibOpInfo(
         "addr",
         core_ops.aten_addr,
@@ -557,8 +581,13 @@ def _where_input_wrangler(
         "any_dim",
         core_ops.aten_any_dim,
     ).skip(
-        matcher=lambda sample: not (len(sample.kwargs) > 0),
-        reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design",
+        matcher=lambda sample: not (len(sample.kwargs) > 0)
+        or isinstance(sample.kwargs.get("dim"), tuple),
+        reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design. dim must be an integer",
+    ),
+    TorchLibOpInfo("any_dims", core_ops.aten_any_dims, trace_only=True).skip(
+        matcher=lambda sample: not isinstance(sample.kwargs.get("dim"), tuple),
+        reason="this overload requires dim to be a tuple",
     ),
     TorchLibOpInfo("asin", core_ops.aten_asin),
     TorchLibOpInfo("asinh", core_ops.aten_asinh),
@@ -640,7 +669,7 @@ def _where_input_wrangler(
             "https://github.com/microsoft/onnxscript/issues/1007"
         ),
     ),
-    TorchLibOpInfo("baddbmm", core_ops.aten_baddbmm),
+    TorchLibOpInfo("baddbmm", core_ops.aten_baddbmm, tolerance={torch.float16: (1e-3, 1e-2)}),
     TorchLibOpInfo("bernoulli", core_ops.aten_bernoulli, nondeterministic=True),
     TorchLibOpInfo(
         # This string is a unique ID. In extra_opinfo.py, we
@@ -845,6 +874,12 @@ def _where_input_wrangler(
         dtypes=(torch.int64, torch.int32),
         reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854",
     )
+    .xfail(
+        variant_name="tensor_overload",
+        dtypes=(torch.int64, torch.int32, torch.float16),
+        reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854",
+        enabled_if=not version_utils.torch_older_than("2.2"),
+    )
     .xfail(
         dtypes=(torch.float16,),
         reason="op 'Range' doesn't support float16.",
@@ -861,17 +896,35 @@ def _where_input_wrangler(
     TorchLibOpInfo(
         "log_softmax",
         special_ops.aten_special_log_softmax,
+        trace_only=True,
         tolerance={torch.float32: (3.7e-5, 1.8e-4), torch.float16: (4e-4, 6e-3)},
-    ).xfail(
+    )
+    .xfail(
+        dtypes=(torch.float16,),
+        reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
+        test_class_name="TestOutputConsistencyFullGraph",
+    )
+    .xfail(
         variant_name="with_dtype",
         dtypes=(torch.float16,),
         reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
         test_class_name="TestOutputConsistencyFullGraph",
+    )
+    .skip(
+        matcher=lambda sample: len(sample.input.shape) == 0,
+        reason="fixme: LogSoftMax does not support empty tensor as input",
+    )
+    .skip(
+        variant_name="with_dtype",
+        matcher=lambda sample: len(sample.input.shape) == 0,
+        reason="fixme: LogSoftMax does not support empty tensor as input",
     ),
     TorchLibOpInfo("log2", core_ops.aten_log2),
     TorchLibOpInfo("logaddexp", core_ops.aten_logaddexp),
     TorchLibOpInfo("logaddexp2", core_ops.aten_logaddexp2),
-    TorchLibOpInfo("logcumsumexp", core_ops.aten_logcumsumexp),
+    TorchLibOpInfo(
+        "logcumsumexp", core_ops.aten_logcumsumexp, tolerance={torch.float16: (1e-2, 1e-1)}
+    ),
     TorchLibOpInfo("logdet", core_ops.aten_logdet),
     TorchLibOpInfo("logsumexp", core_ops.aten_logsumexp),
     TorchLibOpInfo("lt", core_ops.aten_lt),
@@ -884,7 +937,7 @@ def _where_input_wrangler(
         "matmul",
         core_ops.aten_matmul,
         # Windows requires a more relaxed tolerance
-        tolerance={torch.float32: (2e-5, 2e-5)},
+        tolerance={torch.float32: (2e-5, 2e-5), torch.float16: (2e-3, 2e-2)},
     ).skip(
         matcher=lambda sample: torch.numel(sample.input) == 0,
         reason="values of matmul of [m, 0] and [0, n] matrices are undefined",
@@ -1339,12 +1392,28 @@ def _where_input_wrangler(
     TorchLibOpInfo(
         "softmax",
         core_ops.aten_softmax,
+        trace_only=True,
         tolerance={torch.float32: (3.7e-5, 1.8e-4), torch.float16: (3e-4, 4e-4)},
-    ).xfail(
+    )
+    .xfail(
+        dtypes=(torch.float16,),
+        reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
+        test_class_name="TestOutputConsistencyFullGraph",
+    )
+    .xfail(
         variant_name="with_dtype",
         dtypes=(torch.float16,),
         reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
         test_class_name="TestOutputConsistencyFullGraph",
+    )
+    .skip(
+        matcher=lambda sample: len(sample.input.shape) == 0,
+        reason="fixme: SoftMax does not support empty tensor as input",
+    )
+    .skip(
+        variant_name="with_dtype",
+        matcher=lambda sample: len(sample.input.shape) == 0,
+        reason="fixme: SoftMax does not support empty tensor as input",
     ),
     TorchLibOpInfo("nn.functional.softplus", nn_ops.aten_softplus).xfail(
         dtypes=(torch.float16,),
@@ -1700,7 +1769,12 @@ def _where_input_wrangler(
         variant_name="empty_strides",
         reason="fixme: 'shape' do not match: torch.Size([2, 3, 4, 3]) != torch.Size([2, 3, 4, 2]). https://github.com/microsoft/onnxscript/issues/975",
     ),
-    TorchLibOpInfo("native_batch_norm", core_ops.aten_native_batch_norm, trace_only=True),
+    TorchLibOpInfo(
+        "native_batch_norm",
+        core_ops.aten_native_batch_norm,
+        trace_only=True,
+        tolerance={torch.float16: (9e-3, 7e-4)},
+    ),
     TorchLibOpInfo(
         "ops.aten._native_batch_norm_legit", core_ops.aten_native_batch_norm, trace_only=True
     ),
@@ -1719,9 +1793,11 @@ def _where_input_wrangler(
         "ops.aten.native_group_norm",
         core_ops.aten_native_group_norm,
         trace_only=True,
+        tolerance={torch.float16: (1e-2, 7e-3)},
     ).xfail(
         dtypes=(torch.float16,),
         reason="fixme: 'GroupNormKernelImpl' not implemented for 'Half' in nightly and weekly",
+        enabled_if=version_utils.torch_older_than("2.2"),
     ),
     TorchLibOpInfo(
         "native_layer_norm",
@@ -1809,7 +1885,11 @@ def _where_input_wrangler(
         matcher=lambda sample: len(sample.args) != 1,
         reason="this overload is implemented for bias=None",
     ),
-    TorchLibOpInfo("nn.functional.linear_bias", nn_ops.aten_linear_bias).skip(
+    TorchLibOpInfo(
+        "nn.functional.linear_bias",
+        nn_ops.aten_linear_bias,
+        tolerance={torch.float16: (2e-1, 4e-4)},
+    ).skip(
         # input: input, args: weight, bias; so len(args) == 2 means bias is provided
         matcher=lambda sample: len(sample.args) != 2,
         reason="this overload is implemented for bias!=None",
@@ -2059,8 +2139,8 @@ def _where_input_wrangler(
     TorchLibOpInfo("zeros_like", core_ops.aten_zeros_like, trace_only=True),
 )

-ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim",))
-ops_test_common.duplicate_opinfo(OPS_DB, "any", ("any_dim",))
+ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim", "all_dims"))
+ops_test_common.duplicate_opinfo(OPS_DB, "any", ("any_dim", "any_dims"))
 ops_test_common.duplicate_opinfo(OPS_DB, "arange", ("arange_start", "arange_start_step"))
 ops_test_common.duplicate_opinfo(OPS_DB, "argmax", ("argmax_dim",))
 ops_test_common.duplicate_opinfo(OPS_DB, "argmin", ("argmin_dim",))
```

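Most of the remaining edits in ops_test_data.py only widen per-dtype tolerances for float16. Assuming the tuples are (rtol, atol) pairs as the harness reads them, the check they feed looks roughly like this (a sketch, not the harness code):

```python
import torch

tolerance = {torch.float16: (1e-3, 1e-2)}  # assumed (rtol, atol), per the table above

def outputs_match(actual: torch.Tensor, expected: torch.Tensor) -> bool:
    rtol, atol = tolerance.get(actual.dtype, (1e-5, 1e-8))  # defaults mirror torch.allclose
    return torch.allclose(actual, expected, rtol=rtol, atol=atol)

a = torch.tensor([1.000, 2.008], dtype=torch.float16)
b = torch.tensor([1.001, 2.000], dtype=torch.float16)
print(outputs_match(a, b))  # True only under the widened float16 tolerance
```
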