Skip to content

Commit b0e82c1

Browse files
committed
Merge branch 'main' into xiaowu/addOp(index_put)
2 parents 3e751a9 + bff0b52 commit b0e82c1

File tree

2 files changed

+29
-11
lines changed

2 files changed

+29
-11
lines changed

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from typing import Any, Optional, Sequence, Tuple, Union
1515

16-
from onnxscript import BOOL, DOUBLE, FLOAT, INT16, INT32, INT64
16+
from onnxscript import BOOL, DOUBLE, FLOAT, INT8, INT16, INT32, INT64
1717
from onnxscript.function_libs.torch_aten.registration import torch_op
1818
from onnxscript.function_libs.torch_aten.tensor_typing import (
1919
IntType,
@@ -206,16 +206,26 @@ def aten_all_dim(self: TTensor, dim: int, keepdim: bool = False) -> BOOL:
206206
return result
207207

208208

209+
@torch_op("aten::allclose")
209210
def aten_allclose(
210-
self: TensorType,
211-
other: TensorType,
211+
self: TReal,
212+
other: TReal,
212213
rtol: float = 1e-05,
213214
atol: float = 1e-08,
214-
equal_nan: bool = False,
215-
) -> bool:
215+
equal_nan: bool = False, # pylint: disable=unused-argument
216+
) -> BOOL:
216217
"""allclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> bool"""
217218

218-
raise NotImplementedError()
219+
# FIXME: check equal_nan when self and other are all NaN
220+
# |input - other| <= atol + rtol x |other|
221+
left_part = op.Abs(op.Sub(self, other))
222+
right_part = op.Add(atol, op.Mul(rtol, op.Abs(other)))
223+
is_close = op.LessOrEqual(left_part, right_part)
224+
is_close_int = op.Cast(is_close, to=INT8.dtype)
225+
226+
# If min is 0, some elements are not close -> allclose is False
227+
# If min is 1, all elements are close -> allclose is True
228+
return op.Cast(op.ReduceMin(is_close_int, keepdims=0), to=BOOL.dtype)
219229

220230

221231
def aten_alpha_dropout(input: TensorType, p: float, train: bool) -> TensorType:
@@ -2935,16 +2945,22 @@ def aten_is_vulkan_available() -> bool:
29352945
raise NotImplementedError()
29362946

29372947

2948+
@torch_op("aten::isclose")
29382949
def aten_isclose(
2939-
self: TensorType,
2940-
other: TensorType,
2950+
self: TReal,
2951+
other: TReal,
29412952
rtol: float = 1e-05,
29422953
atol: float = 1e-08,
2943-
equal_nan: bool = False,
2944-
) -> TensorType:
2954+
equal_nan: bool = False, # pylint: disable=unused-argument
2955+
) -> BOOL:
29452956
"""isclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> Tensor"""
29462957

2947-
raise NotImplementedError()
2958+
# FIXME: check equal_nan when self and other are all NaN
2959+
# |input - other| <= atol + rtol x |other|
2960+
left_part = op.Abs(op.Sub(self, other))
2961+
right_part = op.Add(atol, op.Mul(rtol, op.Abs(other)))
2962+
result = op.LessOrEqual(left_part, right_part)
2963+
return result
29482964

29492965

29502966
@torch_op("aten::isfinite")

onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,7 @@ def _where_input_wrangler(
358358
],
359359
] = {
360360
"all_dim": core_ops.aten_all_dim,
361+
"allclose": core_ops.aten_allclose,
361362
"all": core_ops.aten_all,
362363
"abs": core_ops.aten_abs,
363364
"acos": core_ops.aten_acos,
@@ -405,6 +406,7 @@ def _where_input_wrangler(
405406
"gt": core_ops.aten_gt,
406407
"index_put_bool": core_ops.aten_index_put_bool,
407408
"index_put": core_ops.aten_index_put,
409+
"isclose": core_ops.aten_isclose,
408410
"isfinite": core_ops.aten_isfinite,
409411
"isinf": core_ops.aten_isinf,
410412
"log": core_ops.aten_log,

0 commit comments

Comments
 (0)