Create env to test with TORCHLIB_EXPERIMENTAL_PREFER_TRACING on | test(torchlib) #1180


Merged
16 commits merged on Nov 23, 2023

Changes from all commits (16 commits):
75551b5
Define the EXPERIMENTAL_PREFER_TRACING flag
justinchuby Nov 22, 2023
b513a23
Create env to test with TORCHLIB_EXPERIMENTAL_PREFER_TRACING on | tes…
justinchuby Nov 22, 2023
581e0b4
Update on "Create env to test with TORCHLIB_EXPERIMENTAL_PREFER_TRACI…
justinchuby Nov 22, 2023
ce12b39
Update on "Create env to test with TORCHLIB_EXPERIMENTAL_PREFER_TRACI…
justinchuby Nov 22, 2023
e47392f
Update base for Update on "Create env to test with TORCHLIB_EXPERIMEN…
justinchuby Nov 22, 2023
3bec3f8
Update on "Create env to test with TORCHLIB_EXPERIMENTAL_PREFER_TRACI…
justinchuby Nov 22, 2023
1d83826
Update base for Update on "Create env to test with TORCHLIB_EXPERIMEN…
justinchuby Nov 22, 2023
0e1b0e6
Update on "Create env to test with TORCHLIB_EXPERIMENTAL_PREFER_TRACI…
justinchuby Nov 22, 2023
ae2388e
Update base for Update on "Create env to test with TORCHLIB_EXPERIMEN…
justinchuby Nov 23, 2023
ffae748
Update on "Create env to test with TORCHLIB_EXPERIMENTAL_PREFER_TRACI…
justinchuby Nov 23, 2023
599b841
Update base for Update on "Create env to test with TORCHLIB_EXPERIMEN…
justinchuby Nov 23, 2023
c8692c3
Update on "Create env to test with TORCHLIB_EXPERIMENTAL_PREFER_TRACI…
justinchuby Nov 23, 2023
3c2afff
Update base for Update on "Create env to test with TORCHLIB_EXPERIMEN…
justinchuby Nov 23, 2023
8c1525e
Update on "Create env to test with TORCHLIB_EXPERIMENTAL_PREFER_TRACI…
justinchuby Nov 23, 2023
ad0245f
Update base for Update on "Create env to test with TORCHLIB_EXPERIMEN…
justinchuby Nov 23, 2023
55bdb27
Update on "Create env to test with TORCHLIB_EXPERIMENTAL_PREFER_TRACI…
justinchuby Nov 23, 2023
8 changes: 8 additions & 0 deletions .github/workflows/main.yaml
@@ -31,6 +31,8 @@ jobs:
           - py310-torch-nightly
           - py310-onnx-weekly
           - py310-ort-nightly
+          - py311-ort-nightly
+          - py310-experimental-torchlib-tracing
         include:
           - name: py310
             python-version: "3.10"
@@ -50,6 +52,12 @@ jobs:
           - name: py310-ort-nightly
             python-version: "3.10"
             nox-tag: test-ort-nightly
+          - name: py311-ort-nightly
+            python-version: "3.11"
+            nox-tag: test-ort-nightly
+          - name: py310-experimental-torchlib-tracing
+            python-version: "3.10"
+            nox-tag: test-experimental-torchlib-tracing
     runs-on: ${{ matrix.os }}
     steps:
       - uses: actions/checkout@v4
19 changes: 18 additions & 1 deletion noxfile.py
@@ -11,7 +11,7 @@

 COMMON_TEST_DEPENDENCIES = (
     "jinja2",
-    "numpy==1.23.5",
+    "numpy==1.24.4",
     "typing_extensions",
     "beartype!=0.16.0",
     "types-PyYAML",
@@ -95,3 +95,20 @@ def test_ort_nightly(session):
     session.install(".", "--no-deps")
     session.run("pip", "list")
     session.run("pytest", "onnxscript", *session.posargs)
+
+
+@nox.session(tags=["test-experimental-torchlib-tracing"])
+def test_experimental_torchlib_tracing(session):
+    """Test TorchLib with the experimental TORCHLIB_EXPERIMENTAL_PREFER_TRACING flag on."""
+    session.install(
+        *COMMON_TEST_DEPENDENCIES, PYTORCH, ONNX, *ONNX_RUNTIME_NIGHTLY_DEPENDENCIES
+    )
+    session.install("-r", "requirements/ci/requirements-ort-nightly.txt")
+    session.install(".", "--no-deps")
+    session.run("pip", "list")
+    session.run(
+        "pytest",
+        "onnxscript/tests/function_libs/torch_lib/ops_test.py",
+        *session.posargs,
+        env={"TORCHLIB_EXPERIMENTAL_PREFER_TRACING": "1"},
+    )
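
Note on the session above: locally it can be selected by tag with "nox -t test-experimental-torchlib-tracing". The noxfile only sets the environment variable; as a minimal sketch of how such a flag is typically read on the TorchLib side (illustration only; the actual helper in torchlib may differ in name and location):

import os


def prefer_tracing_enabled() -> bool:
    # Hypothetical helper: "1"/"true"/"yes" (any case) turn the flag on;
    # unset or anything else leaves it off.
    value = os.getenv("TORCHLIB_EXPERIMENTAL_PREFER_TRACING", "0")
    return value.strip().lower() in ("1", "true", "yes")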
17 changes: 10 additions & 7 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -365,7 +365,9 @@ def aten_all_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False)
     if not dim:
         return aten_all_dims_no_dim(self, keepdim)
     for d in dim:
-        self = aten_all_dim(self, d, keepdim)
+        self = aten_all_dim(self, d, keepdim=True)
+    if not keepdim:
+        self = op.Squeeze(self, list(dim))
     return self


@@ -488,7 +490,9 @@ def aten_any_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False)
     if not dim:
         return aten_any_dims_no_dim(self, keepdim)
     for d in dim:
-        self = aten_any_dim(self, d, keepdim)
+        self = aten_any_dim(self, d, keepdim=True)
+    if not keepdim:
+        self = op.Squeeze(self, list(dim))
     return self


@@ -7339,17 +7343,16 @@ def aten_smm(self: TensorType, mat2: TensorType) -> TensorType:
     raise NotImplementedError()


-@torch_op(("aten::softmax", "aten::softmax.int", "aten::special_softmax"))
-def aten_softmax(
-    self: TFloatOrBFloat16, dim: int, dtype: int = FLOAT.dtype
-) -> TFloatOrBFloat16:
+@torch_op(("aten::softmax", "aten::softmax.int", "aten::special_softmax"), trace_only=True)
+def aten_softmax(self: TFloatOrBFloat16, dim: int, dtype: int = -1) -> TFloatOrBFloat16:
     """softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"""

     self_is_scalar = IsScalar(self)
     if self_is_scalar:
         self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
     result = op.Softmax(self, axis=dim)
-    result = op.Cast(result, to=dtype)
+    if dtype != -1:
+        result = op.Cast(result, to=dtype)
     if self_is_scalar:
         # Convert to scalar when input is scalar
         result = op.Squeeze(result)
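
Why the all_dims/any_dims hunks reduce with keepdim=True and squeeze once at the end: squeezing inside the loop would shift the remaining axis indices and invalidate the caller's dim values. A small PyTorch sketch of the pitfall (illustration only, not code from this PR):

import torch

x = torch.ones(2, 3, 4, dtype=torch.bool)

# Reducing dim 0 without keepdim leaves a 2-D tensor, so a subsequent
# reduction over the caller's dim=2 would be out of range:
# x.all(dim=0).all(dim=2)  # IndexError
# Keeping reduced dims in place until the end preserves the indices:
y = x.all(dim=0, keepdim=True).all(dim=2, keepdim=True)  # shape (1, 3, 1)
y = y.squeeze(2).squeeze(0)  # shape (3,), the keepdim=False result
assert y.shape == (3,)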
8 changes: 4 additions & 4 deletions onnxscript/function_libs/torch_lib/ops/special.py
@@ -13,7 +13,6 @@

 from typing import Optional, Sequence

-from onnxscript import FLOAT
 from onnxscript.function_libs.torch_lib.ops import common as common_ops
 from onnxscript.function_libs.torch_lib.registration import torch_op
 from onnxscript.function_libs.torch_lib.tensor_typing import TFloatOrBFloat16
@@ -212,17 +211,18 @@ def aten_special_log_ndtr(self: TensorType) -> TensorType:
     raise NotImplementedError()


-@torch_op(("aten::log_softmax", "aten::special_log_softmax"))
+@torch_op(("aten::log_softmax", "aten::special_log_softmax"), trace_only=True)
 def aten_special_log_softmax(
-    self: TFloatOrBFloat16, dim: int, dtype: int = FLOAT.dtype
+    self: TFloatOrBFloat16, dim: int, dtype: int = -1
 ) -> TFloatOrBFloat16:
     """special_log_softmax(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor"""

     self_is_scalar = IsScalar(self)
     if self_is_scalar:
         self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
     result = op.LogSoftmax(self, axis=dim)
-    result = op.Cast(result, to=dtype)
+    if dtype != -1:
+        result = op.Cast(result, to=dtype)
     if self_is_scalar:  # squeeze to scalar due to input is scalar
         result = op.Squeeze(result)
     return result
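
In both rewritten functions, dtype: int = -1 is a sentinel meaning "no dtype supplied", so the Cast is emitted only when the exporter actually passes one. This is safe because ONNX element-type codes are non-negative, which a quick check against the onnx package confirms:

import onnx

assert onnx.TensorProto.UNDEFINED == 0  # 0 is reserved for UNDEFINED
assert onnx.TensorProto.FLOAT == 1      # real dtype codes start at 1
assert onnx.TensorProto.FLOAT16 == 10   # so -1 can never collide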
4 changes: 4 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/ops_test_common.py
@@ -161,6 +161,10 @@ def add_decorate_info(
     ops_mapping = {(info.name, info.variant_test_name): info for info in all_opinfos}
     for decorate_meta in skip_or_xfails:
         opinfo = ops_mapping.get((decorate_meta.op_name, decorate_meta.variant_name))
+        if opinfo is None and not decorate_meta.enabled_if:
+            # If the OpInfo doesn't exist and it is not enabled, we skip the OpInfo
+            # because it could be an OpInfo that is in torch-nightly but not older versions.
+            continue
         assert (
             opinfo is not None
         ), f"Couldn't find OpInfo for {decorate_meta}. Did you need to specify variant_name?"
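
The new guard lets a skip/xfail entry reference an OpInfo that only exists in newer torch releases: when the lookup returns None and the entry's enabled_if is already falsy, the entry is dropped instead of tripping the assert. A hypothetical entry that relies on this (the op and function names here are made up for illustration):

TorchLibOpInfo("some_op_new_in_torch_2_2", core_ops.aten_some_op).xfail(
    reason="fixme: mismatch on the overload introduced in torch 2.2",
    enabled_if=not version_utils.torch_older_than("2.2"),
)
# On torch < 2.2 the OpInfo is absent and enabled_if evaluates False, so
# add_decorate_info() now skips the entry; on torch >= 2.2 it applies.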
122 changes: 101 additions & 21 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -471,21 +471,41 @@ def _where_input_wrangler(
     ),
     TorchLibOpInfo("ops.aten._log_softmax", core_ops.aten__log_softmax),
     TorchLibOpInfo(
-        "ops.aten._log_softmax_half", core_ops.aten__log_softmax_half, trace_only=True
-    ).xfail(
+        "ops.aten._log_softmax_half",
+        core_ops.aten__log_softmax_half,
+        trace_only=True,
+        tolerance={torch.float16: (1e-3, 1e-3)},
+    )
+    .xfail(
+        reason="PyTorch does not implement _log_softmax for float16 on CPU",
+        dtypes=(torch.float16,),
+        enabled_if=version_utils.torch_older_than("2.2"),
+    )
+    .xfail(
         dtypes=(torch.float16,),
         reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
         test_class_name="TestOutputConsistencyFullGraph",
     ),
     TorchLibOpInfo("ops.aten._softmax", core_ops.aten__softmax, trace_only=True),
-    TorchLibOpInfo(
-        "ops.aten._softmax_half", core_ops.aten__softmax_half, trace_only=True
-    ).xfail(
+    TorchLibOpInfo("ops.aten._softmax_half", core_ops.aten__softmax_half, trace_only=True)
+    .xfail(
+        reason="PyTorch does not implement _softmax for float16 on CPU",
+        dtypes=(torch.float16,),
+        enabled_if=version_utils.torch_older_than("2.2"),
+    )
+    .xfail(
         dtypes=(torch.float16,),
         reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
         test_class_name="TestOutputConsistencyFullGraph",
     ),
-    TorchLibOpInfo("all_dim", core_ops.aten_all_dim).xfail(
-        matcher=lambda sample: not (len(sample.kwargs) > 0),
-        reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design",
+    TorchLibOpInfo("all_dim", core_ops.aten_all_dim).skip(
+        matcher=lambda sample: not (len(sample.kwargs) > 0)
+        or isinstance(sample.kwargs.get("dim"), tuple),
+        reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design. dim must be an integer",
+    ),
+    TorchLibOpInfo("all_dims", core_ops.aten_all_dims, trace_only=True).skip(
+        matcher=lambda sample: not isinstance(sample.kwargs.get("dim"), tuple),
+        reason="this overload requires dim to be a tuple",
     ),
     TorchLibOpInfo("allclose", core_ops.aten_allclose),
     TorchLibOpInfo(
@@ -501,7 +521,11 @@ def _where_input_wrangler(
     TorchLibOpInfo("acosh", core_ops.aten_acosh),
     TorchLibOpInfo("add", core_ops.aten_add, tolerance={torch.float16: (1e-3, 1e-3)}),
     TorchLibOpInfo("add", core_ops.aten_add_complex, complex=True, trace_only=True),
-    TorchLibOpInfo("addbmm", core_ops.aten_addbmm, tolerance={torch.float32: (2e-5, 2e-5)}),
+    TorchLibOpInfo(
+        "addbmm",
+        core_ops.aten_addbmm,
+        tolerance={torch.float32: (2e-5, 2e-5), torch.float16: (2e-1, 2e-2)},
+    ),
     TorchLibOpInfo("addcdiv", core_ops.aten_addcdiv),
     TorchLibOpInfo("addcmul", core_ops.aten_addcmul, tolerance={torch.float16: (4e-3, 3e-3)}),
     TorchLibOpInfo("addmm", core_ops.aten_addmm)
@@ -522,7 +546,7 @@ def _where_input_wrangler(
         dtypes=(torch.int16, torch.int32, torch.int64),
         reason="ONNX Runtime does not support int inputs to Gemm",
     ),
-    TorchLibOpInfo("addmv", core_ops.aten_addmv),
+    TorchLibOpInfo("addmv", core_ops.aten_addmv, tolerance={torch.float16: (1e-3, 1e-2)}),
     TorchLibOpInfo(
         "addr",
         core_ops.aten_addr,
@@ -557,8 +581,13 @@ def _where_input_wrangler(
         "any_dim",
         core_ops.aten_any_dim,
     ).skip(
-        matcher=lambda sample: not (len(sample.kwargs) > 0),
-        reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design",
+        matcher=lambda sample: not (len(sample.kwargs) > 0)
+        or isinstance(sample.kwargs.get("dim"), tuple),
+        reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design. dim must be an integer",
+    ),
+    TorchLibOpInfo("any_dims", core_ops.aten_any_dims, trace_only=True).skip(
+        matcher=lambda sample: not isinstance(sample.kwargs.get("dim"), tuple),
+        reason="this overload requires dim to be a tuple",
     ),
     TorchLibOpInfo("asin", core_ops.aten_asin),
     TorchLibOpInfo("asinh", core_ops.aten_asinh),
@@ -640,7 +669,7 @@ def _where_input_wrangler(
             "https://github.com/microsoft/onnxscript/issues/1007"
         ),
     ),
-    TorchLibOpInfo("baddbmm", core_ops.aten_baddbmm),
+    TorchLibOpInfo("baddbmm", core_ops.aten_baddbmm, tolerance={torch.float16: (1e-3, 1e-2)}),
    TorchLibOpInfo("bernoulli", core_ops.aten_bernoulli, nondeterministic=True),
     TorchLibOpInfo(
         # This string is a unique ID. In extra_opinfo.py, we
@@ -845,6 +874,12 @@ def _where_input_wrangler(
         dtypes=(torch.int64, torch.int32),
         reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854",
     )
+    .xfail(
+        variant_name="tensor_overload",
+        dtypes=(torch.int64, torch.int32, torch.float16),
+        reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854",
+        enabled_if=not version_utils.torch_older_than("2.2"),
+    )
     .xfail(
         dtypes=(torch.float16,),
         reason="op 'Range' doesn't support float16.",
@@ -861,17 +896,35 @@ def _where_input_wrangler(
     TorchLibOpInfo(
         "log_softmax",
         special_ops.aten_special_log_softmax,
+        trace_only=True,
         tolerance={torch.float32: (3.7e-5, 1.8e-4), torch.float16: (4e-4, 6e-3)},
-    ).xfail(
+    )
+    .xfail(
         dtypes=(torch.float16,),
         reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
         test_class_name="TestOutputConsistencyFullGraph",
+    )
+    .xfail(
+        variant_name="with_dtype",
+        dtypes=(torch.float16,),
+        reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
+        test_class_name="TestOutputConsistencyFullGraph",
+    )
+    .skip(
+        matcher=lambda sample: len(sample.input.shape) == 0,
+        reason="fixme: LogSoftMax does not support empty tensor as input",
+    )
+    .skip(
+        variant_name="with_dtype",
+        matcher=lambda sample: len(sample.input.shape) == 0,
+        reason="fixme: LogSoftMax does not support empty tensor as input",
     ),
     TorchLibOpInfo("log2", core_ops.aten_log2),
     TorchLibOpInfo("logaddexp", core_ops.aten_logaddexp),
     TorchLibOpInfo("logaddexp2", core_ops.aten_logaddexp2),
-    TorchLibOpInfo("logcumsumexp", core_ops.aten_logcumsumexp),
+    TorchLibOpInfo(
+        "logcumsumexp", core_ops.aten_logcumsumexp, tolerance={torch.float16: (1e-2, 1e-1)}
+    ),
     TorchLibOpInfo("logdet", core_ops.aten_logdet),
     TorchLibOpInfo("logsumexp", core_ops.aten_logsumexp),
     TorchLibOpInfo("lt", core_ops.aten_lt),
@@ -884,7 +937,7 @@ def _where_input_wrangler(
         "matmul",
         core_ops.aten_matmul,
         # Windows requires a more relaxed tolerance
-        tolerance={torch.float32: (2e-5, 2e-5)},
+        tolerance={torch.float32: (2e-5, 2e-5), torch.float16: (2e-3, 2e-2)},
     ).skip(
         matcher=lambda sample: torch.numel(sample.input) == 0,
         reason="values of matmul of [m, 0] and [0, n] matrices are undefined",
@@ -1339,12 +1392,28 @@ def _where_input_wrangler(
     TorchLibOpInfo(
         "softmax",
         core_ops.aten_softmax,
+        trace_only=True,
         tolerance={torch.float32: (3.7e-5, 1.8e-4), torch.float16: (3e-4, 4e-4)},
-    ).xfail(
+    )
+    .xfail(
         dtypes=(torch.float16,),
         reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
         test_class_name="TestOutputConsistencyFullGraph",
+    )
+    .xfail(
+        variant_name="with_dtype",
+        dtypes=(torch.float16,),
+        reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
+        test_class_name="TestOutputConsistencyFullGraph",
+    )
+    .skip(
+        matcher=lambda sample: len(sample.input.shape) == 0,
+        reason="fixme: SoftMax does not support empty tensor as input",
+    )
+    .skip(
+        variant_name="with_dtype",
+        matcher=lambda sample: len(sample.input.shape) == 0,
+        reason="fixme: SoftMax does not support empty tensor as input",
     ),
     TorchLibOpInfo("nn.functional.softplus", nn_ops.aten_softplus).xfail(
         dtypes=(torch.float16,),
@@ -1700,7 +1769,12 @@ def _where_input_wrangler(
         variant_name="empty_strides",
         reason="fixme: 'shape' do not match: torch.Size([2, 3, 4, 3]) != torch.Size([2, 3, 4, 2]). https://github.com/microsoft/onnxscript/issues/975",
     ),
-    TorchLibOpInfo("native_batch_norm", core_ops.aten_native_batch_norm, trace_only=True),
+    TorchLibOpInfo(
+        "native_batch_norm",
+        core_ops.aten_native_batch_norm,
+        trace_only=True,
+        tolerance={torch.float16: (9e-3, 7e-4)},
+    ),
     TorchLibOpInfo(
         "ops.aten._native_batch_norm_legit", core_ops.aten_native_batch_norm, trace_only=True
     ),
@@ -1719,9 +1793,11 @@
         "ops.aten.native_group_norm",
         core_ops.aten_native_group_norm,
         trace_only=True,
+        tolerance={torch.float16: (1e-2, 7e-3)},
     ).xfail(
         dtypes=(torch.float16,),
         reason="fixme: 'GroupNormKernelImpl' not implemented for 'Half' in nightly and weekly",
+        enabled_if=version_utils.torch_older_than("2.2"),
     ),
     TorchLibOpInfo(
         "native_layer_norm",
@@ -1809,7 +1885,11 @@ def _where_input_wrangler(
         matcher=lambda sample: len(sample.args) != 1,
         reason="this overload is implemented for bias=None",
     ),
-    TorchLibOpInfo("nn.functional.linear_bias", nn_ops.aten_linear_bias).skip(
+    TorchLibOpInfo(
+        "nn.functional.linear_bias",
+        nn_ops.aten_linear_bias,
+        tolerance={torch.float16: (2e-1, 4e-4)},
+    ).skip(
         # input: input, args: weight, bias; so len(args) == 2 means bias is provided
         matcher=lambda sample: len(sample.args) != 2,
         reason="this overload is implemented for bias!=None",
@@ -2059,8 +2139,8 @@ def _where_input_wrangler(
     TorchLibOpInfo("zeros_like", core_ops.aten_zeros_like, trace_only=True),
 )

-ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim",))
-ops_test_common.duplicate_opinfo(OPS_DB, "any", ("any_dim",))
+ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim", "all_dims"))
+ops_test_common.duplicate_opinfo(OPS_DB, "any", ("any_dim", "any_dims"))
 ops_test_common.duplicate_opinfo(OPS_DB, "arange", ("arange_start", "arange_start_step"))
 ops_test_common.duplicate_opinfo(OPS_DB, "argmax", ("argmax_dim",))
 ops_test_common.duplicate_opinfo(OPS_DB, "argmin", ("argmin_dim",))
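
For context on the last two changed lines: duplicate_opinfo registers copies of an existing torch OpInfo under extra names so that each TorchLib overload (all_dim, all_dims, and so on) gets its own test entry, which is why the new all_dims/any_dims names pair with the TorchLibOpInfo entries added above. A simplified sketch of the idea (the real helper in ops_test_common.py has additional bookkeeping):

import copy


def duplicate_opinfo_sketch(ops_db, name, new_names):
    # Append deep copies of the matching OpInfo under each new name.
    (original,) = [info for info in ops_db if info.name == name]
    for new_name in new_names:
        clone = copy.deepcopy(original)
        clone.name = new_name
        ops_db.append(clone)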