|
13 | 13 |
|
14 | 14 | from typing import Any, Optional, Sequence, Tuple, Union |
15 | 15 |
|
16 | | -from onnxscript import BOOL, DOUBLE, FLOAT, INT16, INT32, INT64 |
| 16 | +from onnxscript import BOOL, DOUBLE, FLOAT, INT8, INT16, INT32, INT64 |
17 | 17 | from onnxscript.function_libs.torch_aten.registration import torch_op |
18 | 18 | from onnxscript.function_libs.torch_aten.tensor_typing import ( |
19 | 19 | IntType, |
@@ -206,16 +206,26 @@ def aten_all_dim(self: TTensor, dim: int, keepdim: bool = False) -> BOOL: |
206 | 206 | return result |
207 | 207 |
|
208 | 208 |
|
| 209 | +@torch_op("aten::allclose") |
209 | 210 | def aten_allclose( |
210 | | - self: TensorType, |
211 | | - other: TensorType, |
| 211 | + self: TReal, |
| 212 | + other: TReal, |
212 | 213 | rtol: float = 1e-05, |
213 | 214 | atol: float = 1e-08, |
214 | | - equal_nan: bool = False, |
215 | | -) -> bool: |
| 215 | + equal_nan: bool = False, # pylint: disable=unused-argument |
| 216 | +) -> BOOL: |
216 | 217 | """allclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> bool""" |
217 | 218 |
|
218 | | - raise NotImplementedError() |
| 219 | + # FIXME: check equal_nan when self and other are all NaN |
| 220 | + # |input - other| <= atol + rtol x |other| |
| 221 | + left_part = op.Abs(op.Sub(self, other)) |
| 222 | + right_part = op.Add(atol, op.Mul(rtol, op.Abs(other))) |
| 223 | + is_close = op.LessOrEqual(left_part, right_part) |
| 224 | + is_close_int = op.Cast(is_close, to=INT8.dtype) |
| 225 | + |
| 226 | + # If min is 0, some elements are not close -> allclose is False |
| 227 | + # If min is 1, all elements are close -> allclose is True |
| 228 | + return op.Cast(op.ReduceMin(is_close_int, keepdims=0), to=BOOL.dtype) |
219 | 229 |
|
220 | 230 |
|
221 | 231 | def aten_alpha_dropout(input: TensorType, p: float, train: bool) -> TensorType: |
@@ -2935,16 +2945,22 @@ def aten_is_vulkan_available() -> bool: |
2935 | 2945 | raise NotImplementedError() |
2936 | 2946 |
|
2937 | 2947 |
|
| 2948 | +@torch_op("aten::isclose") |
2938 | 2949 | def aten_isclose( |
2939 | | - self: TensorType, |
2940 | | - other: TensorType, |
| 2950 | + self: TReal, |
| 2951 | + other: TReal, |
2941 | 2952 | rtol: float = 1e-05, |
2942 | 2953 | atol: float = 1e-08, |
2943 | | - equal_nan: bool = False, |
2944 | | -) -> TensorType: |
| 2954 | + equal_nan: bool = False, # pylint: disable=unused-argument |
| 2955 | +) -> BOOL: |
2945 | 2956 | """isclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> Tensor""" |
2946 | 2957 |
|
2947 | | - raise NotImplementedError() |
| 2958 | + # FIXME: check equal_nan when self and other are all NaN |
| 2959 | + # |input - other| <= atol + rtol x |other| |
| 2960 | + left_part = op.Abs(op.Sub(self, other)) |
| 2961 | + right_part = op.Add(atol, op.Mul(rtol, op.Abs(other))) |
| 2962 | + result = op.LessOrEqual(left_part, right_part) |
| 2963 | + return result |
2948 | 2964 |
|
2949 | 2965 |
|
2950 | 2966 | @torch_op("aten::isfinite") |
|
0 commit comments