Skip to content

Commit 8aba4df

Browse files
committed
Update base for Update on "Implement the experimental evaluator for folding branches and castlikes | feat(torchlib)"
As an effort described in #1095, this PR: implements the experimental evaluator for folding branches and castlikes so that they are eagerly evaluated when possible; updates the implementation for `addr` so that it is traceable; and conditionally enables previously xfailed tests. Set `TORCHLIB_EXPERIMENTAL_PREFER_TRACING=1` and tested in CI. E.g. clamp_min now becomes ``` < ir_version: 8, opset_import: ["" : 18, "pkg.onnxscript.torch_lib.common" : 1], producer_name: "pytorch", producer_version: "2.2.0" > main_graph (int32[5,10,5] input_0, int32[10,5] input_1) => (int32[5,10,5] _val_9) <int64 _val_2, int64[2] _val_3, int64 _val_4, int64 _val_5, bool _val_6, int64 _val_7, bool _val_8> { _val_2 = Size (input_0) _val_3 = Shape <start: int = 0> (input_1) _val_4 = Size (_val_3) _val_5 = Constant <value: tensor = int64 {0}> () _val_6 = Equal (_val_2, _val_5) _val_7 = Constant <value: tensor = int64 {0}> () _val_8 = Equal (_val_4, _val_7) _val_9 = Max (input_0, input_1) } < domain: "pkg.onnxscript.torch_lib.common", opset_import: ["" : 18] > Rank (input) => (return_val) { tmp = Shape (input) return_val = Size (tmp) } < domain: "pkg.onnxscript.torch_lib.common", opset_import: ["" : 18] > IsScalar (input) => (return_val) { tmp = Shape (input) tmp_0 = Size (tmp) tmp_1 = Constant <value_int: int = 0> () return_val = Equal (tmp_0, tmp_1) } ``` [ghstack-poisoned]
1 parent 275e4c5 commit 8aba4df

File tree

2 files changed

+26
-10
lines changed

2 files changed

+26
-10
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,9 @@ def aten_all_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False)
365365
if not dim:
366366
return aten_all_dims_no_dim(self, keepdim)
367367
for d in dim:
368-
self = aten_all_dim(self, d, keepdim)
368+
self = aten_all_dim(self, d, keepdim=True)
369+
if not keepdim:
370+
self = op.Squeeze(self, list(dim))
369371
return self
370372

371373

@@ -488,7 +490,9 @@ def aten_any_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False)
488490
if not dim:
489491
return aten_any_dims_no_dim(self, keepdim)
490492
for d in dim:
491-
self = aten_any_dim(self, d, keepdim)
493+
self = aten_any_dim(self, d, keepdim=True)
494+
if not keepdim:
495+
self = op.Squeeze(self, list(dim))
492496
return self
493497

494498

onnxscript/tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -483,9 +483,14 @@ def _where_input_wrangler(
483483
reason="PyTorch does not implement _softmax for float16 on CPU",
484484
dtypes=(torch.float16,),
485485
),
486-
TorchLibOpInfo("all_dim", core_ops.aten_all_dim).xfail(
487-
matcher=lambda sample: not (len(sample.kwargs) > 0),
488-
reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design",
486+
TorchLibOpInfo("all_dim", core_ops.aten_all_dim).skip(
487+
matcher=lambda sample: not (len(sample.kwargs) > 0)
488+
or isinstance(sample.kwargs.get("dim"), tuple),
489+
reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design. dim must be an integer",
490+
),
491+
TorchLibOpInfo("all_dims", core_ops.aten_all_dims, trace_only=True).skip(
492+
matcher=lambda sample: not isinstance(sample.kwargs.get("dim"), tuple),
493+
reason="this overload requires dim to be a tuple",
489494
),
490495
TorchLibOpInfo("allclose", core_ops.aten_allclose),
491496
TorchLibOpInfo(
@@ -561,8 +566,13 @@ def _where_input_wrangler(
561566
"any_dim",
562567
core_ops.aten_any_dim,
563568
).skip(
564-
matcher=lambda sample: not (len(sample.kwargs) > 0),
565-
reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design",
569+
matcher=lambda sample: not (len(sample.kwargs) > 0)
570+
or isinstance(sample.kwargs.get("dim"), tuple),
571+
reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design. dim must be an integer",
572+
),
573+
TorchLibOpInfo("any_dims", core_ops.aten_any_dims, trace_only=True).skip(
574+
matcher=lambda sample: not isinstance(sample.kwargs.get("dim"), tuple),
575+
reason="this overload requires dim to be a tuple",
566576
),
567577
TorchLibOpInfo("asin", core_ops.aten_asin),
568578
TorchLibOpInfo("asinh", core_ops.aten_asinh),
@@ -881,7 +891,9 @@ def _where_input_wrangler(
881891
TorchLibOpInfo("log2", core_ops.aten_log2),
882892
TorchLibOpInfo("logaddexp", core_ops.aten_logaddexp),
883893
TorchLibOpInfo("logaddexp2", core_ops.aten_logaddexp2),
884-
TorchLibOpInfo("logcumsumexp", core_ops.aten_logcumsumexp),
894+
TorchLibOpInfo(
895+
"logcumsumexp", core_ops.aten_logcumsumexp, tolerance={torch.float16: (1e-2, 1e-1)}
896+
),
885897
TorchLibOpInfo("logdet", core_ops.aten_logdet),
886898
TorchLibOpInfo("logsumexp", core_ops.aten_logsumexp),
887899
TorchLibOpInfo("lt", core_ops.aten_lt),
@@ -2080,8 +2092,8 @@ def _where_input_wrangler(
20802092
TorchLibOpInfo("zeros_like", core_ops.aten_zeros_like, trace_only=True),
20812093
)
20822094

2083-
ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim",))
2084-
ops_test_common.duplicate_opinfo(OPS_DB, "any", ("any_dim",))
2095+
ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim", "all_dims"))
2096+
ops_test_common.duplicate_opinfo(OPS_DB, "any", ("any_dim", "any_dims"))
20852097
ops_test_common.duplicate_opinfo(OPS_DB, "arange", ("arange_start", "arange_start_step"))
20862098
ops_test_common.duplicate_opinfo(OPS_DB, "argmax", ("argmax_dim",))
20872099
ops_test_common.duplicate_opinfo(OPS_DB, "argmin", ("argmin_dim",))

0 commit comments

Comments
 (0)