Skip to content

Commit 26419ef

Browse files
feat(atenlib): Add Ops (all, isfinite, split_with_sizes) (#513)
Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 946d12e commit 26419ef

File tree

2 files changed

+50
-5
lines changed

2 files changed

+50
-5
lines changed

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -176,10 +176,35 @@ def aten_align_to(self: TensorType, names: Sequence[str]) -> TensorType:
176176
raise NotImplementedError()
177177

178178

@torch_op("aten::all")
def aten_all(self: TTensor) -> BOOL:
    """all(Tensor self) -> Tensor"""

    # Size(Shape(self)) is the rank; rank 0 means a scalar input, where the
    # reduction is a no-op and a plain cast to bool suffices.
    if op.Size(op.Shape(self)) == 0:
        result = op.Cast(self, to=BOOL.dtype)
    else:
        # ReduceMin does not operate on bool tensors, so reduce in INT64:
        # after casting truthiness to 0/1, min over all elements is 1 iff
        # every element is truthy — i.e. torch.all semantics.
        self_bool = op.Cast(self, to=BOOL.dtype)
        self_int = op.Cast(self_bool, to=INT64.dtype)
        result_int = op.ReduceMin(self_int, keepdims=0)
        result = op.Cast(result_int, to=BOOL.dtype)

    return result
192+
193+
194+
@torch_op("aten::all", overload=True)
def aten_all_dim(self: TTensor, dim: int, keepdim: bool = False) -> BOOL:
    """all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor"""

    # Size(Shape(self)) is the rank; rank 0 means a scalar input, where there
    # is nothing to reduce and a plain cast to bool suffices.
    if op.Size(op.Shape(self)) == 0:
        result = op.Cast(self, to=BOOL.dtype)
    else:
        # ReduceMin does not operate on bool tensors, so reduce in INT64:
        # after casting truthiness to 0/1, min along `dim` is 1 iff every
        # element along that axis is truthy — i.e. torch.all(dim=...) semantics.
        self_bool = op.Cast(self, to=BOOL.dtype)
        self_int = op.Cast(self_bool, to=INT64.dtype)
        # ReduceMin takes its axes as a 1-D tensor input, so reshape the
        # scalar `dim` attribute into shape [-1].
        dims = op.Reshape(dim, op.Constant(value_ints=[-1]))
        result_int = op.ReduceMin(self_int, dims, keepdims=keepdim)
        result = op.Cast(result_int, to=BOOL.dtype)

    return result
183208

184209

185210
def aten_allclose(
@@ -2899,10 +2924,13 @@ def aten_isclose(
28992924
raise NotImplementedError()
29002925

29012926

@torch_op("aten::isfinite")
def aten_isfinite(self: TensorType) -> TensorType:
    """isfinite(Tensor self) -> Tensor"""

    # A value is finite iff it is neither infinite nor NaN; combine the two
    # negated element-wise checks with And.
    not_inf = op.Not(op.IsInf(self))
    not_nan = op.Not(op.IsNaN(self))  # TODO: the test case doesn't cover this condition
    return op.And(not_inf, not_nan)
29062934

29072935

29082936
@torch_op("aten::isinf")
@@ -5187,10 +5215,11 @@ def aten_split_copy(self: TensorType, split_size: INT64, dim: int = 0) -> Tensor
51875215
raise NotImplementedError()
51885216

51895217

@torch_op("aten::split_with_sizes")
def aten_split_with_sizes(self: TTensor, split_sizes: INT64, dim: int = 0) -> TTensor:
    """split_with_sizes(Tensor(a -> *) self, SymInt[] split_sizes, int dim=0) -> Tensor(a)[]"""

    # SplitToSequence yields a sequence of tensors whose extents along `dim`
    # are the entries of `split_sizes`, matching torch.split_with_sizes.
    return op.SplitToSequence(self, split_sizes, axis=dim)
51945223

51955224

51965225
def aten_split_with_sizes_copy(

onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,8 @@ def _where_input_wrangler(
283283
Callable[[list[Any], dict[str, Any]], tuple[list[Any], dict[str, Any]]],
284284
],
285285
] = {
286+
"all_dim": core_ops.aten_all_dim,
287+
"all": core_ops.aten_all,
286288
"abs": core_ops.aten_abs,
287289
"acos": core_ops.aten_acos,
288290
"acosh": core_ops.aten_acosh,
@@ -325,6 +327,7 @@ def _where_input_wrangler(
325327
"full_like": core_ops.aten_full_like,
326328
"ge": core_ops.aten_ge,
327329
"gt": core_ops.aten_gt,
330+
"isfinite": core_ops.aten_isfinite,
328331
"isinf": core_ops.aten_isinf,
329332
"log": core_ops.aten_log,
330333
"le": core_ops.aten_le,
@@ -385,6 +388,7 @@ def _where_input_wrangler(
385388
"sin": core_ops.aten_sin,
386389
"sinh": core_ops.aten_sinh,
387390
"softmax": special_ops.aten_special_softmax,
391+
"split_with_sizes": core_ops.aten_split_with_sizes,
388392
"split": core_ops.aten_split,
389393
"sqrt": core_ops.aten_sqrt,
390394
"stack": core_ops.aten_stack,
@@ -478,6 +482,16 @@ def _where_input_wrangler(
478482

479483

480484
SKIP_SUBTESTS: tuple[DecorateMeta, ...] = (
485+
skip(
486+
"all",
487+
matcher=lambda sample: not (len(sample.kwargs) == 0),
488+
reason="this Aten overload only support one tensor as input by design",
489+
),
490+
skip(
491+
"all_dim",
492+
matcher=lambda sample: not (len(sample.kwargs) > 0),
493+
reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design",
494+
),
481495
skip(
482496
"arange",
483497
matcher=lambda sample: len(sample.args) != 0,
@@ -594,6 +608,8 @@ def _where_input_wrangler(
594608
),
595609
)
596610

611+
duplicate_opinfo(OPS_DB, "all", ("all_dim",))
612+
597613
duplicate_opinfo(
598614
OPS_DB,
599615
"arange",

0 commit comments

Comments
 (0)