
Commit 0e24ee9

Update base for Update on "Mark some functions as traceable | feat(torchlib)"

As part of the effort described in #1095, this PR marks functions with if branches as traceable.

[ghstack-poisoned]

Parent: 1866cb8

File tree

2 files changed: +26 -10 lines

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 6 additions & 2 deletions

@@ -365,7 +365,9 @@ def aten_all_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False)
     if not dim:
         return aten_all_dims_no_dim(self, keepdim)
     for d in dim:
-        self = aten_all_dim(self, d, keepdim)
+        self = aten_all_dim(self, d, keepdim=True)
+    if not keepdim:
+        self = op.Squeeze(self, list(dim))
     return self

@@ -488,7 +490,9 @@ def aten_any_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False)
     if not dim:
         return aten_any_dims_no_dim(self, keepdim)
     for d in dim:
-        self = aten_any_dim(self, d, keepdim)
+        self = aten_any_dim(self, d, keepdim=True)
+    if not keepdim:
+        self = op.Squeeze(self, list(dim))
     return self
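Read together, the two hunks make the same fix: each axis in dim is reduced with keepdim=True so the remaining indices in dim keep referring to the intended axes across loop iterations, and the reduced axes are then dropped in a single Squeeze when keepdim is False. Reassembled from the hunk above, the patched aten_all_dims reads roughly as follows (TTensor, op, and the helper overloads are torchlib names the diff itself does not show):

    def aten_all_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False):
        if not dim:
            # No dims given: reduce over all axes via the dedicated overload.
            return aten_all_dims_no_dim(self, keepdim)
        for d in dim:
            # Keep dims during the loop so each index in `dim` still points
            # at the axis it named in the original tensor.
            self = aten_all_dim(self, d, keepdim=True)
        if not keepdim:
            # Drop all reduced axes at once, after the loop.
            self = op.Squeeze(self, list(dim))
        return self

aten_any_dims is identical apart from calling aten_any_dim and aten_any_dims_no_dim.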

onnxscript/tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 20 additions & 8 deletions

@@ -483,9 +483,14 @@ def _where_input_wrangler(
         reason="PyTorch does not implement _softmax for float16 on CPU",
         dtypes=(torch.float16,),
     ),
-    TorchLibOpInfo("all_dim", core_ops.aten_all_dim).xfail(
-        matcher=lambda sample: not (len(sample.kwargs) > 0),
-        reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design",
+    TorchLibOpInfo("all_dim", core_ops.aten_all_dim).skip(
+        matcher=lambda sample: not (len(sample.kwargs) > 0)
+        or isinstance(sample.kwargs.get("dim"), tuple),
+        reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design. dim must be an integer",
+    ),
+    TorchLibOpInfo("all_dims", core_ops.aten_all_dims, trace_only=True).skip(
+        matcher=lambda sample: not isinstance(sample.kwargs.get("dim"), tuple),
+        reason="this overload requires dim to be a tuple",
     ),
     TorchLibOpInfo("allclose", core_ops.aten_allclose),
     TorchLibOpInfo(

@@ -561,8 +566,13 @@ def _where_input_wrangler(
         "any_dim",
         core_ops.aten_any_dim,
     ).skip(
-        matcher=lambda sample: not (len(sample.kwargs) > 0),
-        reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design",
+        matcher=lambda sample: not (len(sample.kwargs) > 0)
+        or isinstance(sample.kwargs.get("dim"), tuple),
+        reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design. dim must be an integer",
+    ),
+    TorchLibOpInfo("any_dims", core_ops.aten_any_dims, trace_only=True).skip(
+        matcher=lambda sample: not isinstance(sample.kwargs.get("dim"), tuple),
+        reason="this overload requires dim to be a tuple",
     ),
     TorchLibOpInfo("asin", core_ops.aten_asin),
     TorchLibOpInfo("asinh", core_ops.aten_asinh),
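Taken together, the paired skip matchers partition the shared test samples by the type of dim: samples with an integer dim (or no kwargs at all) exercise the *_dim overloads, while samples with a tuple dim exercise the new trace-only *_dims overloads. A small self-contained check of that routing, using hypothetical stand-in samples (the real suite supplies its own sample objects):

    from types import SimpleNamespace

    # Hypothetical stand-ins for the suite's sample inputs.
    sample_int = SimpleNamespace(kwargs={"dim": 1, "keepdim": True})
    sample_tuple = SimpleNamespace(kwargs={"dim": (0, 1)})

    # The matchers exactly as written in the hunks above.
    skip_dim = lambda sample: not (len(sample.kwargs) > 0) or isinstance(
        sample.kwargs.get("dim"), tuple
    )
    skip_dims = lambda sample: not isinstance(sample.kwargs.get("dim"), tuple)

    assert skip_dim(sample_tuple) and not skip_dim(sample_int)    # *_dim runs int dims only
    assert skip_dims(sample_int) and not skip_dims(sample_tuple)  # *_dims runs tuple dims only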
@@ -881,7 +891,9 @@ def _where_input_wrangler(
     TorchLibOpInfo("log2", core_ops.aten_log2),
     TorchLibOpInfo("logaddexp", core_ops.aten_logaddexp),
     TorchLibOpInfo("logaddexp2", core_ops.aten_logaddexp2),
-    TorchLibOpInfo("logcumsumexp", core_ops.aten_logcumsumexp),
+    TorchLibOpInfo(
+        "logcumsumexp", core_ops.aten_logcumsumexp, tolerance={torch.float16: (1e-2, 1e-1)}
+    ),
     TorchLibOpInfo("logdet", core_ops.aten_logdet),
     TorchLibOpInfo("logsumexp", core_ops.aten_logsumexp),
     TorchLibOpInfo("lt", core_ops.aten_lt),
@@ -2080,8 +2092,8 @@ def _where_input_wrangler(
     TorchLibOpInfo("zeros_like", core_ops.aten_zeros_like, trace_only=True),
 )

-ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim",))
-ops_test_common.duplicate_opinfo(OPS_DB, "any", ("any_dim",))
+ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim", "all_dims"))
+ops_test_common.duplicate_opinfo(OPS_DB, "any", ("any_dim", "any_dims"))
 ops_test_common.duplicate_opinfo(OPS_DB, "arange", ("arange_start", "arange_start_step"))
 ops_test_common.duplicate_opinfo(OPS_DB, "argmax", ("argmax_dim",))
 ops_test_common.duplicate_opinfo(OPS_DB, "argmin", ("argmin_dim",))
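The duplicate_opinfo calls register the new overload names so that the all_dims and any_dims entries above have OpInfo samples to run against. Its implementation lives in ops_test_common and is not part of this diff; conceptually it behaves like this sketch (assuming OPS_DB is a mutable sequence of OpInfo objects with a writable name attribute):

    import copy

    def duplicate_opinfo(ops_db, name, new_names):
        # Clone the existing OpInfo under each new name so every overload
        # can carry its own skips, xfails, and tolerances.
        (original,) = [info for info in ops_db if info.name == name]
        for new_name in new_names:
            clone = copy.deepcopy(original)
            clone.name = new_name
            ops_db.append(clone)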
