Skip to content

Commit 0e1b0e6

Browse files
committed
Update on "Create env to test with TORCHLIB_EXPERIMENTAL_PREFER_TRACING on | test(torchlib)"
### Changes - Create env to test with TORCHLIB_EXPERIMENTAL_PREFER_TRACING on - Test with Python 3.11 as well [ghstack-poisoned]
2 parents 3bec3f8 + 1d83826 commit 0e1b0e6

File tree

4 files changed

+27
-16
lines changed

4 files changed

+27
-16
lines changed

onnxscript/function_libs/torch_lib/_flags.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,3 @@ def _load_boolean_flag(
3939
"TORCHLIB_EXPERIMENTAL_INITIALIZERS_AS_INPUTS",
4040
this_will="make initializers as inputs to the model graph",
4141
)
42-
EXPERIMENTAL_PREFER_TRACING: bool = _load_boolean_flag(
43-
"TORCHLIB_EXPERIMENTAL_PREFER_TRACING",
44-
this_will="trace all traceable functions to fold if branches and collapse constant expressions",
45-
)

onnxscript/function_libs/torch_lib/registration.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,6 @@ def torch_op(
9999
trace_only: bool = False,
100100
private: bool = False,
101101
complex: bool = False,
102-
traceable: bool = False,
103102
) -> Callable[[FunctionType], onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction]:
104103
"""Register a torch op.
105104
@@ -113,7 +112,6 @@ def torch_op(
113112
private: Whether the function is private (not directly exposed). It should
114113
be true for all functions with names starting with "_".
115114
complex: Whether the function expects complex-valued inputs.
116-
traceable: Whether the function can be traced.
117115
"""
118116
if registry is None:
119117
registry = default_registry
@@ -130,7 +128,6 @@ def wrapper(
130128
else:
131129
assert isinstance(func, FunctionType)
132130
processed_func = onnxscript.script(opset=custom_opset)(func)
133-
processed_func.experimental_traceable = traceable
134131

135132
assert registry is not None
136133
for name_ in _check_and_normalize_names(name):

onnxscript/tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,11 @@ def _where_input_wrangler(
501501
TorchLibOpInfo("acosh", core_ops.aten_acosh),
502502
TorchLibOpInfo("add", core_ops.aten_add, tolerance={torch.float16: (1e-3, 1e-3)}),
503503
TorchLibOpInfo("add", core_ops.aten_add_complex, complex=True, trace_only=True),
504-
TorchLibOpInfo("addbmm", core_ops.aten_addbmm, tolerance={torch.float32: (2e-5, 2e-5)}),
504+
TorchLibOpInfo(
505+
"addbmm",
506+
core_ops.aten_addbmm,
507+
tolerance={torch.float32: (2e-5, 2e-5), torch.float16: (2e-1, 2e-2)},
508+
),
505509
TorchLibOpInfo("addcdiv", core_ops.aten_addcdiv),
506510
TorchLibOpInfo("addcmul", core_ops.aten_addcmul, tolerance={torch.float16: (4e-3, 3e-3)}),
507511
TorchLibOpInfo("addmm", core_ops.aten_addmm)
@@ -522,7 +526,7 @@ def _where_input_wrangler(
522526
dtypes=(torch.int16, torch.int32, torch.int64),
523527
reason="ONNX Runtime does not support int inputs to Gemm",
524528
),
525-
TorchLibOpInfo("addmv", core_ops.aten_addmv),
529+
TorchLibOpInfo("addmv", core_ops.aten_addmv, tolerance={torch.float16: (1e-3, 1e-2)}),
526530
TorchLibOpInfo(
527531
"addr",
528532
core_ops.aten_addr,
@@ -640,7 +644,7 @@ def _where_input_wrangler(
640644
"https://github.com/microsoft/onnxscript/issues/1007"
641645
),
642646
),
643-
TorchLibOpInfo("baddbmm", core_ops.aten_baddbmm),
647+
TorchLibOpInfo("baddbmm", core_ops.aten_baddbmm, tolerance={torch.float16: (1e-3, 1e-2)}),
644648
TorchLibOpInfo("bernoulli", core_ops.aten_bernoulli, nondeterministic=True),
645649
TorchLibOpInfo(
646650
# This string is a unique ID. In extra_opinfo.py, we
@@ -845,6 +849,12 @@ def _where_input_wrangler(
845849
dtypes=(torch.int64, torch.int32),
846850
reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854",
847851
)
852+
.xfail(
853+
variant_name="tensor_overload",
854+
dtypes=(torch.int64, torch.int32, torch.float16),
855+
reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854",
856+
enabled_if=not version_utils.torch_older_than("2.1"),
857+
)
848858
.xfail(
849859
dtypes=(torch.float16,),
850860
reason="op 'Range' doesn't support float16.",
@@ -884,7 +894,7 @@ def _where_input_wrangler(
884894
"matmul",
885895
core_ops.aten_matmul,
886896
# Windows requires a more relaxed tolerance
887-
tolerance={torch.float32: (2e-5, 2e-5)},
897+
tolerance={torch.float32: (2e-5, 2e-5), torch.float16: (2e-3, 2e-2)},
888898
).skip(
889899
matcher=lambda sample: torch.numel(sample.input) == 0,
890900
reason="values of matmul of [m, 0] and [0, n] matrices are undefined",
@@ -1700,7 +1710,12 @@ def _where_input_wrangler(
17001710
variant_name="empty_strides",
17011711
reason="fixme: 'shape' do not match: torch.Size([2, 3, 4, 3]) != torch.Size([2, 3, 4, 2]). https://github.com/microsoft/onnxscript/issues/975",
17021712
),
1703-
TorchLibOpInfo("native_batch_norm", core_ops.aten_native_batch_norm, trace_only=True),
1713+
TorchLibOpInfo(
1714+
"native_batch_norm",
1715+
core_ops.aten_native_batch_norm,
1716+
trace_only=True,
1717+
tolerance={torch.float16: (9e-3, 7e-4)},
1718+
),
17041719
TorchLibOpInfo(
17051720
"ops.aten._native_batch_norm_legit", core_ops.aten_native_batch_norm, trace_only=True
17061721
),
@@ -1719,9 +1734,11 @@ def _where_input_wrangler(
17191734
"ops.aten.native_group_norm",
17201735
core_ops.aten_native_group_norm,
17211736
trace_only=True,
1737+
tolerance={torch.float16: (1e-2, 7e-3)},
17221738
).xfail(
17231739
dtypes=(torch.float16,),
17241740
reason="fixme: 'GroupNormKernelImpl' not implemented for 'Half' in nightly and weekly",
1741+
enabled_if=version_utils.torch_older_than("2.2"),
17251742
),
17261743
TorchLibOpInfo(
17271744
"native_layer_norm",
@@ -1809,7 +1826,11 @@ def _where_input_wrangler(
18091826
matcher=lambda sample: len(sample.args) != 1,
18101827
reason="this overload is implemented for bias=None",
18111828
),
1812-
TorchLibOpInfo("nn.functional.linear_bias", nn_ops.aten_linear_bias).skip(
1829+
TorchLibOpInfo(
1830+
"nn.functional.linear_bias",
1831+
nn_ops.aten_linear_bias,
1832+
tolerance={torch.float16: (2e-1, 4e-4)},
1833+
).skip(
18131834
# input: input, args: weight, bias; so len(args) == 2 means bias is provided
18141835
matcher=lambda sample: len(sample.args) != 2,
18151836
reason="this overload is implemented for bias!=None",

onnxscript/values.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -479,9 +479,6 @@ def __init__(
479479
self._param_schemas: Optional[tuple[ParamSchema, ...]] = None
480480
self._op_schema: Optional[onnx.defs.OpSchema] = None
481481

482-
# Experimental fields
483-
self.experimental_traceable = False
484-
485482
@property
486483
@deprecation.deprecated(
487484
since="0.1",

0 commit comments

Comments
 (0)