Commit 688f677

Update on "Define the EXPERIMENTAL_PREFER_TRACING flag and the traceable option | feat(torchlib)"
As part of the effort described in #1095, this PR:

- adds an experimental `TORCHLIB_EXPERIMENTAL_PREFER_TRACING` flag to allow the tracer to trace a function when possible.
- defines the `traceable` option in the `torch_op` decorator to mark a function as traceable.

[ghstack-poisoned]
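For context, the sketch below illustrates the two pieces named above. It is a hedged example, not code from this commit: the target op `aten::relu`, the function name, and the assumption that `TORCHLIB_EXPERIMENTAL_PREFER_TRACING` is read from the environment are illustrative guesses.

```python
import os

# Assumption: the experimental flag is consumed as an environment variable;
# this commit page does not show how the flag is actually read.
os.environ["TORCHLIB_EXPERIMENTAL_PREFER_TRACING"] = "1"

from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.function_libs.torch_lib.tensor_typing import TFloatOrBFloat16
from onnxscript.onnx_opset import opset18 as op


# `traceable=True` marks the function as safe for the tracer to trace
# (run the Python body eagerly) when the experimental flag is enabled.
@torch_op("aten::relu", traceable=True)  # hypothetical target op, for illustration only
def aten_relu_sketch(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
    return op.Relu(self)
```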

3 files changed: +62 lines, -16 lines

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 4 additions & 5 deletions
@@ -7343,17 +7343,16 @@ def aten_smm(self: TensorType, mat2: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
-@torch_op(("aten::softmax", "aten::softmax.int", "aten::special_softmax"))
-def aten_softmax(
-    self: TFloatOrBFloat16, dim: int, dtype: int = FLOAT.dtype
-) -> TFloatOrBFloat16:
+@torch_op(("aten::softmax", "aten::softmax.int", "aten::special_softmax"), trace_only=True)
+def aten_softmax(self: TFloatOrBFloat16, dim: int, dtype: int = -1) -> TFloatOrBFloat16:
     """softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"""
 
     self_is_scalar = IsScalar(self)
     if self_is_scalar:
         self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
     result = op.Softmax(self, axis=dim)
-    result = op.Cast(result, to=dtype)
+    if dtype != -1:
+        result = op.Cast(result, to=dtype)
     if self_is_scalar:
         # Convert to scalar when input is scalar
         result = op.Squeeze(result)
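Note on the change above (explanatory, not part of the diff): because the function is now `trace_only=True`, its Python body runs eagerly at export time, so `-1` can act as a sentinel for "no dtype requested" and the `Cast` node is emitted only when a dtype is actually passed. A minimal, self-contained NumPy sketch of that sentinel pattern (stand-in code, not the torch_lib API):

```python
import numpy as np

# Small ONNX-dtype-code -> NumPy-dtype map, just for this demo (1 = FLOAT, 10 = FLOAT16).
_ONNX_TO_NP = {1: np.float32, 10: np.float16}


def softmax_like(x: np.ndarray, dim: int, dtype: int = -1) -> np.ndarray:
    """Plain-NumPy stand-in showing how -1 means 'ScalarType? dtype=None'."""
    e = np.exp(x - x.max(axis=dim, keepdims=True))
    result = e / e.sum(axis=dim, keepdims=True)
    if dtype != -1:  # evaluated eagerly, like the Python `if` in the trace-only op
        result = result.astype(_ONNX_TO_NP[dtype])
    return result


x = np.random.rand(2, 3).astype(np.float32)
print(softmax_like(x, dim=1).dtype)            # float32: no cast when dtype is omitted
print(softmax_like(x, dim=1, dtype=10).dtype)  # float16: cast only when requested
```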

onnxscript/function_libs/torch_lib/ops/special.py

Lines changed: 4 additions & 4 deletions
@@ -13,7 +13,6 @@
 
 from typing import Optional, Sequence
 
-from onnxscript import FLOAT
 from onnxscript.function_libs.torch_lib.ops import common as common_ops
 from onnxscript.function_libs.torch_lib.registration import torch_op
 from onnxscript.function_libs.torch_lib.tensor_typing import TFloatOrBFloat16
@@ -212,17 +211,18 @@ def aten_special_log_ndtr(self: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
-@torch_op(("aten::log_softmax", "aten::special_log_softmax"))
+@torch_op(("aten::log_softmax", "aten::special_log_softmax"), trace_only=True)
 def aten_special_log_softmax(
-    self: TFloatOrBFloat16, dim: int, dtype: int = FLOAT.dtype
+    self: TFloatOrBFloat16, dim: int, dtype: int = -1
 ) -> TFloatOrBFloat16:
     """special_log_softmax(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor"""
 
     self_is_scalar = IsScalar(self)
     if self_is_scalar:
         self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
     result = op.LogSoftmax(self, axis=dim)
-    result = op.Cast(result, to=dtype)
+    if dtype != -1:
+        result = op.Cast(result, to=dtype)
     if self_is_scalar:  # squeeze to scalar due to input is scalar
         result = op.Squeeze(result)
     return result

onnxscript/tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 54 additions & 7 deletions
@@ -471,17 +471,32 @@ def _where_input_wrangler(
     ),
     TorchLibOpInfo("ops.aten._log_softmax", core_ops.aten__log_softmax),
     TorchLibOpInfo(
-        "ops.aten._log_softmax_half", core_ops.aten__log_softmax_half, trace_only=True
-    ).xfail(
+        "ops.aten._log_softmax_half",
+        core_ops.aten__log_softmax_half,
+        trace_only=True,
+        tolerance={torch.float16: (1e-3, 1e-3)},
+    )
+    .xfail(
         reason="PyTorch does not implement _log_softmax for float16 on CPU",
         dtypes=(torch.float16,),
+        enabled_if=version_utils.torch_older_than("2.2"),
+    )
+    .xfail(
+        dtypes=(torch.float16,),
+        reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
+        test_class_name="TestOutputConsistencyFullGraph",
     ),
     TorchLibOpInfo("ops.aten._softmax", core_ops.aten__softmax, trace_only=True),
-    TorchLibOpInfo(
-        "ops.aten._softmax_half", core_ops.aten__softmax_half, trace_only=True
-    ).xfail(
+    TorchLibOpInfo("ops.aten._softmax_half", core_ops.aten__softmax_half, trace_only=True)
+    .xfail(
         reason="PyTorch does not implement _softmax for float16 on CPU",
         dtypes=(torch.float16,),
+        enabled_if=version_utils.torch_older_than("2.2"),
+    )
+    .xfail(
+        dtypes=(torch.float16,),
+        reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
+        test_class_name="TestOutputConsistencyFullGraph",
     ),
     TorchLibOpInfo("all_dim", core_ops.aten_all_dim).skip(
         matcher=lambda sample: not (len(sample.kwargs) > 0)
@@ -881,12 +896,28 @@ def _where_input_wrangler(
     TorchLibOpInfo(
         "log_softmax",
         special_ops.aten_special_log_softmax,
+        trace_only=True,
         tolerance={torch.float32: (3.7e-5, 1.8e-4), torch.float16: (4e-4, 6e-3)},
-    ).xfail(
+    )
+    .xfail(
+        dtypes=(torch.float16,),
+        reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
+        test_class_name="TestOutputConsistencyFullGraph",
+    )
+    .xfail(
         variant_name="with_dtype",
         dtypes=(torch.float16,),
        reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
         test_class_name="TestOutputConsistencyFullGraph",
+    )
+    .skip(
+        matcher=lambda sample: len(sample.input.shape) == 0,
+        reason="fixme: LogSoftMax does not support empty tensor as input",
+    )
+    .skip(
+        variant_name="with_dtype",
+        matcher=lambda sample: len(sample.input.shape) == 0,
+        reason="fixme: LogSoftMax does not support empty tensor as input",
     ),
     TorchLibOpInfo("log2", core_ops.aten_log2),
     TorchLibOpInfo("logaddexp", core_ops.aten_logaddexp),
@@ -1361,12 +1392,28 @@ def _where_input_wrangler(
     TorchLibOpInfo(
         "softmax",
         core_ops.aten_softmax,
+        trace_only=True,
         tolerance={torch.float32: (3.7e-5, 1.8e-4), torch.float16: (3e-4, 4e-4)},
-    ).xfail(
+    )
+    .xfail(
+        dtypes=(torch.float16,),
+        reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
+        test_class_name="TestOutputConsistencyFullGraph",
+    )
+    .xfail(
         variant_name="with_dtype",
         dtypes=(torch.float16,),
         reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
         test_class_name="TestOutputConsistencyFullGraph",
+    )
+    .skip(
+        matcher=lambda sample: len(sample.input.shape) == 0,
+        reason="fixme: SoftMax does not support empty tensor as input",
+    )
+    .skip(
+        variant_name="with_dtype",
+        matcher=lambda sample: len(sample.input.shape) == 0,
+        reason="fixme: SoftMax does not support empty tensor as input",
     ),
     TorchLibOpInfo("nn.functional.softplus", nn_ops.aten_softplus).xfail(
         dtypes=(torch.float16,),
