
Commit b48a098

Update on "Define the EXPERIMENTAL_PREFER_TRACING flag and the traceable option | feat(torchlib)"
As part of the effort described in #1095, this PR:

- adds an experimental `TORCHLIB_EXPERIMENTAL_PREFER_TRACING` flag that allows the tracer to trace a function when possible.
- defines the `traceable` option in the `torch_op` decorator to mark a function as traceable.

[ghstack-poisoned]
2 parents 688f677 + eeb1ff7 commit b48a098
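For context, a minimal sketch of how the new flag would be switched on. Only the variable name comes from the commit message; where torchlib reads it, and which values count as enabled, is not shown in this diff:

```python
import os

# Opt in to the experimental tracing preference before torchlib builds its
# functions. Treating "1" as "enabled" is an assumption for illustration;
# only the variable name TORCHLIB_EXPERIMENTAL_PREFER_TRACING is given by
# this commit.
os.environ["TORCHLIB_EXPERIMENTAL_PREFER_TRACING"] = "1"
```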

File tree: 5 files changed, +60 −15 lines

onnxscript/function_libs/torch_lib/graph_building.py

Lines changed: 1 addition & 0 deletions

@@ -816,6 +816,7 @@ def generate_function_value_info_proto(
                 continue
             if prefix:
                 name = f"{prefix}/{name}"
+            value_info.name = name
             named_value_info[name] = value_info
         for name, sub_graph in self._sub_torch_script_graphs.items():
             named_value_info.update(
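The added line keeps the `ValueInfoProto`'s own `name` field in sync with the prefixed dictionary key; without it, a consumer reading `value_info.name` would still see the unprefixed name. A minimal standalone illustration (the prefix value here is made up):

```python
import onnx

# Build a ValueInfoProto named "x", then prefix it the way
# generate_function_value_info_proto does for subgraphs.
value_info = onnx.helper.make_tensor_value_info("x", onnx.TensorProto.FLOAT, [1])
prefix = "submodule"  # hypothetical subgraph prefix
name = f"{prefix}/{value_info.name}"
value_info.name = name  # the fix: rename the proto itself, not just the dict key
named_value_info = {name: value_info}
print(named_value_info["submodule/x"].name)  # -> "submodule/x"
```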

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 7 additions & 4 deletions

@@ -4576,9 +4576,10 @@ def aten_logaddexp(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrB
 @torch_op("aten::logaddexp2")
 def aten_logaddexp2(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
     """logaddexp2(Tensor self, Tensor other) -> Tensor"""
-    summation = op.Add(op.Pow(2.0, self), op.Pow(2.0, other))
+    two = op.CastLike(2.0, self)
+    summation = op.Add(op.Pow(two, self), op.Pow(two, other))

-    return op.Div(op.Log(summation), op.CastLike(op.Log(2.0), self))
+    return op.Div(op.Log(summation), op.Log(two))


 @torch_op("aten::logcumsumexp")
@@ -4673,10 +4674,12 @@ def _aten_logit_onnx(self: TFloatOrBFloat16) -> TFloatOrBFloat16:

 @torch_op("aten::logit", private=True)
 def _aten_logit_clamp_onnx(self: TFloatOrBFloat16, eps: float) -> TFloatOrBFloat16:
-    temporary_self = op.Where(self <= 1.0 - eps, self, 1.0 - eps)
+    eps = op.CastLike(eps, self)
+    one = op.CastLike(1.0, self)
+    temporary_self = op.Where(self <= one - eps, self, one - eps)
     z = op.Where(temporary_self < eps, eps, temporary_self)

-    return op.Log(op.Div(z, op.Sub(1.0, z)))
+    return op.Log(op.Div(z, op.Sub(one, z)))


 @torch_op("aten::logit", trace_only=True)

onnxscript/function_libs/torch_lib/registration.py

Lines changed: 12 additions & 1 deletion

@@ -113,7 +113,18 @@ def torch_op(
         private: Whether the function is private (not directly exposed). It should
             be true for all functions with names starting with "_".
         complex: Whether the function expects complex-valued inputs.
-        traceable: Whether the function can be traced.
+        traceable: Whether the function can also be traced. This is an **experimental** flag.
+            A function is traceable if it can both be scripted and traced to produce
+            the same result for a given input. Specifically:
+
+            - A function _can_ be tagged with traceable if its if branches (if any)
+              can be statically evaluated.
+            - A function _should_ be tagged with traceable if it contains if branches
+              and/or CastLike nodes, so that they can be evaluated away with
+              EXPERIMENTAL_PREFER_TRACING on.
+            - A function without if branches or CastLike nodes _should not_ be tagged
+              with traceable because inlining will do the same thing.
+            - A function with `@graph` defined for a `Scan` op is not traceable yet.
     """
     if registry is None:
         registry = default_registry
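A sketch of how the option is applied, following the rules in the docstring. The import paths mirror what torch_lib uses (an assumption on our part), and tagging this particular function is illustrative, not part of this commit:

```python
# Import paths as used in torch_lib (assumption; adjust to your checkout).
from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.function_libs.torch_lib.tensor_typing import TFloatOrBFloat16
from onnxscript.onnx_opset import opset18 as op


# CastLike nodes make this the kind of candidate that SHOULD be tagged, so
# the tracer can evaluate them away when EXPERIMENTAL_PREFER_TRACING is on.
# A function with no if branches or CastLike should NOT be tagged, since
# inlining already flattens it.
@torch_op("aten::logaddexp2", traceable=True)  # hypothetical tagging
def aten_logaddexp2(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
    """logaddexp2(Tensor self, Tensor other) -> Tensor"""
    two = op.CastLike(2.0, self)
    summation = op.Add(op.Pow(two, self), op.Pow(two, other))
    return op.Div(op.Log(summation), op.Log(two))
```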

onnxscript/tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 2 additions & 2 deletions

@@ -69,7 +69,7 @@ def sample_inputs_conv3d(op_info, device, dtype, requires_grad, **kwargs):
             (32,),
             {
                 "stride": (3, 3, 3),
-                "padding": 2,
+                "padding": (2, 2, 2),
                 "dilation": (1, 1, 1),
                 "groups": 1,
             },
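In PyTorch itself the int shorthand and the per-dimension tuple produce identical results for conv3d; the opinfo now spells out the tuple form, presumably to suit the renamed `ops.aten.conv3d` entry (second hunk below) and its explicit per-dimension parameters. A quick equivalence check:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8, 8)
w = torch.randn(2, 3, 3, 3, 3)

# padding=2 is shorthand for padding=(2, 2, 2) on the three spatial dims.
a = F.conv3d(x, w, stride=(3, 3, 3), padding=2)
b = F.conv3d(x, w, stride=(3, 3, 3), padding=(2, 2, 2))
print(torch.equal(a, b))  # True
```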
@@ -1394,7 +1394,7 @@ def sample_inputs__native_batch_norm_legit_no_stats(
         supports_out=False,
     ),
     opinfo_core.OpInfo(
-        "nn.functional.conv3d",
+        "ops.aten.conv3d",
         aten_name="conv3d",
         dtypes=common_dtype.floating_and_complex_types_and(torch.int64, torch.bfloat16),
         sample_inputs_func=sample_inputs_conv3d,

onnxscript/tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 38 additions & 8 deletions

@@ -526,7 +526,7 @@ def _where_input_wrangler(
         core_ops.aten_addbmm,
         tolerance={torch.float32: (2e-5, 2e-5), torch.float16: (2e-1, 2e-2)},
     ),
-    TorchLibOpInfo("addcdiv", core_ops.aten_addcdiv),
+    TorchLibOpInfo("addcdiv", core_ops.aten_addcdiv, tolerance={torch.float16: (3e-2, 1e-3)}),
     TorchLibOpInfo("addcmul", core_ops.aten_addcmul, tolerance={torch.float16: (4e-3, 3e-3)}),
     TorchLibOpInfo("addmm", core_ops.aten_addmm)
     .xfail(
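The `tolerance` entries added throughout this file read as per-dtype `(rtol, atol)` pairs, consistent with the existing float32/float16 entries above. Roughly what such an entry amounts to when outputs are compared; the harness's plumbing is not shown in this diff, so the pair order is an assumption:

```python
import torch

# tolerance={torch.float16: (3e-2, 1e-3)} ~ rtol=3e-2, atol=1e-3 for fp16 runs
# (assumed order; the harness internals are outside this diff).
rtol, atol = 3e-2, 1e-3
actual = torch.tensor([1.02], dtype=torch.float16)
expected = torch.tensor([1.00], dtype=torch.float16)
torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)  # passes
print("within tolerance")
```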
@@ -592,7 +592,7 @@ def _where_input_wrangler(
     TorchLibOpInfo("asin", core_ops.aten_asin),
     TorchLibOpInfo("asinh", core_ops.aten_asinh),
     TorchLibOpInfo("atan", core_ops.aten_atan),
-    TorchLibOpInfo("atan2", core_ops.aten_atan2),
+    TorchLibOpInfo("atan2", core_ops.aten_atan2, tolerance={torch.float16: (1e-3, 1e-3)}),
     TorchLibOpInfo("atanh", core_ops.aten_atanh),
     TorchLibOpInfo("atleast_1d", core_ops.aten_atleast_1d).skip(
         matcher=lambda sample: isinstance(sample.input, (list, tuple)),
@@ -737,7 +737,7 @@ def _where_input_wrangler(
     # TorchLibOpInfo("copy", core_ops.aten_copy),  # copy is not in OPS_DB
     TorchLibOpInfo("cos", core_ops.aten_cos),
     TorchLibOpInfo("cosh", core_ops.aten_cosh),
-    TorchLibOpInfo("cross", core_ops.aten_cross),
+    TorchLibOpInfo("cross", core_ops.aten_cross, tolerance={torch.float16: (6e-3, 3e-3)}),
     # TorchLibOpInfo("detach", core_ops.aten_detach),  # detach is not in OP-TEST-DB
     TorchLibOpInfo("diagonal", core_ops.aten_diagonal, trace_only=True),
     TorchLibOpInfo("diagonal_bool", core_ops.aten_diagonal_bool, trace_only=True),
@@ -920,8 +920,10 @@ def _where_input_wrangler(
         reason="fixme: LogSoftMax does not support empty tensor as input",
     ),
     TorchLibOpInfo("log2", core_ops.aten_log2),
-    TorchLibOpInfo("logaddexp", core_ops.aten_logaddexp),
-    TorchLibOpInfo("logaddexp2", core_ops.aten_logaddexp2),
+    TorchLibOpInfo("logaddexp", core_ops.aten_logaddexp, tolerance={torch.float16: (1, 1e-4)}),
+    TorchLibOpInfo(
+        "logaddexp2", core_ops.aten_logaddexp2, tolerance={torch.float16: (2e-2, 6e-4)}
+    ),
     TorchLibOpInfo(
         "logcumsumexp", core_ops.aten_logcumsumexp, tolerance={torch.float16: (1e-2, 1e-1)}
     ),
@@ -1087,10 +1089,16 @@ def _where_input_wrangler(
     TorchLibOpInfo(
         "nn.functional.adaptive_avg_pool1d",
         nn_ops.aten_adaptive_avg_pool1d,
-    ).xfail(
+    )
+    .xfail(
         # Shape should be [N, C, D1]
         matcher=lambda sample: sample.args[0] not in {1, (1,)},
         reason="only global pooling is supported; only batched inputs are supported",
+    )
+    .xfail(
+        reason="ORT fails on a cast node it inserts for float16. https://github.com/microsoft/onnxruntime/issues/16449",
+        dtypes=(torch.float16,),
+        test_class_name="TestOutputConsistencyEager",
     ),
     TorchLibOpInfo(
         "nn.functional.adaptive_avg_pool2d",
@@ -1718,7 +1726,9 @@ def _where_input_wrangler(
         dtypes=(torch.int64,),
         reason="fixme: ORT `LayerNormKernelImpl` not implemented for int64",
     ),
-    TorchLibOpInfo("logit", core_ops.aten_logit, trace_only=True),
+    TorchLibOpInfo(
+        "logit", core_ops.aten_logit, trace_only=True, tolerance={torch.float16: (1e-1, 7e-4)}
+    ),
     TorchLibOpInfo("max_dim", core_ops.aten_max_dim)
     .skip(
         variant_name="reduction_with_dim",
@@ -1869,7 +1879,7 @@ def _where_input_wrangler(
         reason="String padding is not accepted by aten::conv2d",
     ),
     TorchLibOpInfo(
-        "nn.functional.conv3d",
+        "ops.aten.conv3d",
         core_ops.aten_conv3d,
         trace_only=True,
         tolerance={torch.float32: (3.7e-5, 1.8e-4)},
@@ -1974,6 +1984,16 @@ def _where_input_wrangler(
     .skip(
         matcher=lambda sample: sample.kwargs.get("dropout_p") != 0.0,
         reason="dropout is random so the results do not match",
+    )
+    .xfail(
+        dtypes=(torch.float16,),
+        reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
+        test_class_name="TestOutputConsistencyFullGraph",
+    )
+    .xfail(
+        reason="fixme: ORT fails on type mismatch in Add",
+        dtypes=(torch.float16,),
+        test_class_name="TestOutputConsistencyEager",
     ),
     TorchLibOpInfo(
         "ops.aten._scaled_dot_product_flash_attention",
@@ -2000,6 +2020,16 @@ def _where_input_wrangler(
     .skip(
         matcher=lambda sample: sample.kwargs.get("dropout_p") != 0.0,
         reason="dropout is random so the results do not match",
+    )
+    .xfail(
+        dtypes=(torch.float16,),
+        reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
+        test_class_name="TestOutputConsistencyFullGraph",
+    )
+    .xfail(
+        reason="fixme: ORT fails on type mismatch in Add",
+        dtypes=(torch.float16,),
+        test_class_name="TestOutputConsistencyEager",
     ),
     TorchLibOpInfo(
         "nn.functional.upsample_bilinear2d",
