Skip to content

Commit 94a0f5c

Browse files
authored
Merge branch 'main' into xiaowu/addOps(0216)
2 parents 58ba823 + 10395c8 commit 94a0f5c

File tree

2 files changed

+31
-4
lines changed

2 files changed

+31
-4
lines changed

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -269,10 +269,32 @@ def aten_angle(self: TensorType) -> TensorType:
269269
raise NotImplementedError()
270270

271271

272-
def aten_any(self: TensorType) -> TensorType:
272+
@torch_op("aten::any", trace_only=True)
273+
def aten_any(self: TTensor, dim: Optional[int] = None, keepdim: bool = True) -> BOOL:
273274
"""any(Tensor self) -> Tensor"""
274275

275-
raise NotImplementedError()
276+
minus_1 = op.Constant(value_ints=[-1])
277+
self_rank = op.Size(op.Shape(self))
278+
if self_rank == 0:
279+
self = op.Reshape(self, minus_1)
280+
281+
zero = op.Constant(value_float=0.0)
282+
result = op.Not(op.Equal(self, zero))
283+
# because op.ReduceMax() cannot calculate BOOL value
284+
result_float = op.Cast(result, to=FLOAT.dtype)
285+
286+
if op.OptionalHasElement(dim):
287+
dim = op.Reshape(dim, minus_1)
288+
dims = op.Cast(dim, to=INT64.dtype)
289+
result_max = op.ReduceMax(result_float, dims, keepdims=keepdim, noop_with_empty_axes=0)
290+
else:
291+
result_max = op.ReduceMax(result_float, keepdims=0, noop_with_empty_axes=0)
292+
293+
result = op.Greater(result_max, zero)
294+
if self_rank == 0:
295+
result = op.Squeeze(result)
296+
297+
return result
276298

277299

278300
def _range_supported(dtype: int) -> bool:
@@ -2044,10 +2066,13 @@ def aten_expand(self: TTensor, size: TInt) -> TTensor:
20442066
return op.Expand(self, size)
20452067

20462068

2047-
def aten_expand_as(self: TensorType, other: TensorType) -> TensorType:
2069+
@torch_op("aten::expand_as")
2070+
def aten_expand_as(self: TTensor, other: TTensor) -> TTensor:
20482071
"""expand_as(Tensor(a) self, Tensor other) -> Tensor(a)"""
20492072

2050-
raise NotImplementedError()
2073+
shape = op.Shape(other)
2074+
result = op.Expand(self, shape)
2075+
return result
20512076

20522077

20532078
def aten_expand_copy(self: TensorType, size: INT64, implicit: bool = False) -> TensorType:

onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,7 @@ def _where_input_wrangler(
314314
"exp": core_ops.aten_exp,
315315
"exp2": core_ops.aten_exp2,
316316
"expand": core_ops.aten_expand,
317+
"expand_as": core_ops.aten_expand_as,
317318
"erf": core_ops.aten_erf,
318319
"fill": core_ops.aten_fill,
319320
"fmod": core_ops.aten_fmod,
@@ -398,6 +399,7 @@ def _where_input_wrangler(
398399
] = {
399400
"amax": core_ops.aten_amax,
400401
"amin": core_ops.aten_amin,
402+
"any": core_ops.aten_any, # TODO: add more testcase which element is [0.0, 0.1, -0.1, 0.0] etc.
401403
"arange_start_step": core_ops.aten_arange_start_step,
402404
"arange_start": core_ops.aten_arange_start,
403405
"arange": core_ops.aten_arange,

0 commit comments

Comments
 (0)