diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index fb2172736a..377e25eb86 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -739,24 +739,55 @@ def aten_chunk(self: TensorType, chunks: int, dim: int = 0) -> TensorType:
     raise NotImplementedError()


-def aten_clamp(
-    self: TensorType, min: Optional[float] = None, max: Optional[float] = None
-) -> TensorType:
+def aten_clamp(self: TensorType, min_=None, max_=None) -> TensorType:
     # clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor

-    raise NotImplementedError()
+    # TODO(justinchuby): Handle integer inputs
+    # FIXME(justinchuby): Enable test for this after None values are supported
+    # TODO(justinchuby): If min is greater than max torch.clamp(..., min, max)
+    # sets all elements in input to the value of max.
+    if op.OptionalHasElement(min_):
+        min_ = op.OptionalGetElement(min_)
+        min_clamp = op.CastLike(min_, self)  # type: ignore[arg-type]
+    else:
+        min_clamp = op.Constant(value_float=float("-inf"))
+
+    if op.OptionalHasElement(max_):
+        max_ = op.OptionalGetElement(max_)
+        max_clamp = op.CastLike(max_, self)  # type: ignore[arg-type]
+    else:
+        max_clamp = op.Constant(value_float=float("inf"))
+
+    # Enforce the lower and upper bounds
+    clamped = op.Max(op.Min(self, max_clamp), min_clamp)  # type: ignore[arg-type]
+    return clamped


-def aten_clamp_max(self: TensorType, max: float) -> TensorType:
+def aten_clamp_max_scalar(self, max_):
     # clamp_max(Tensor self, Scalar max) -> Tensor

-    raise NotImplementedError()
+    max_ = op.CastLike(max_, self)
+    return op.Clip(self, None, max_)
+
+
+def aten_clamp_max_tensor(self, max_):
+    # clamp_max(Tensor self, Tensor max) -> Tensor
+
+    return op.Min(self, max_)


-def aten_clamp_min(self: TensorType, min: float) -> TensorType:
+def aten_clamp_min_scalar(self, min_):
     # clamp_min(Tensor self, Scalar min) -> Tensor
+    # NOTE: min_ is a rank 0 tensor.
+    # TODO(justinchuby): Specify the type constraints.
+    min_ = op.CastLike(min_, self)
+    return op.Clip(self, min_, None)

-    raise NotImplementedError()
+
+def aten_clamp_min_tensor(self, min_):
+    # clamp_min(Tensor self, Tensor min) -> Tensor
+    # TODO(justinchuby): Specify the type constraints.
+    return op.Max(self, min_)


 def aten_clip(
@@ -1958,10 +1989,12 @@ def aten_gru_cell(
     raise NotImplementedError()


-def aten_gt(self: TensorType, other: TensorType) -> TensorType:
+def aten_gt(self, other):
     # gt.Tensor(Tensor self, Tensor other) -> Tensor

-    raise NotImplementedError()
+    # TODO(justinchuby): Input spec: non bool tensor
+    # Boolean inputs can be pre-casted by policy
+    return op.Greater(self, other)


 def aten_hamming_window(window_length: int) -> TensorType:
@@ -2572,10 +2605,12 @@ def aten_lstm_mps_backward(
     raise NotImplementedError()


-def aten_lt(self: TensorType, other: TensorType) -> TensorType:
+def aten_lt(self, other):
     # lt.Tensor(Tensor self, Tensor other) -> Tensor

-    raise NotImplementedError()
+    # TODO(justinchuby): Input spec: non bool tensor
+    # Boolean inputs can be pre-casted by policy
+    return op.Less(self, other)


 def aten_lu_solve(self: TensorType, LU_data: TensorType, LU_pivots: TensorType) -> TensorType:
@@ -3440,10 +3475,20 @@ def aten_ones(size: INT64) -> TensorType:
     raise NotImplementedError()


-def aten_ones_like(self: TensorType, memory_format: Optional[str] = None) -> TensorType:
+def aten_ones_like(self, dtype: int = -1):
+    """ones_like.
+
+    Note: dtype is an onnx enum. Users should convert torch dtype to onnx dtype
+    before calling this function.
+    """
     # ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor

-    raise NotImplementedError()
+    shape = op.Shape(self)
+    if dtype == -1:
+        one = op.CastLike(1, self)  # type: ignore[arg-type]
+    else:
+        one = op.Cast(1, to=dtype)  # type: ignore[arg-type]
+    return op.Expand(one, shape)


 def aten_or(self: TensorType, other: TensorType) -> TensorType:
@@ -3916,10 +3961,19 @@ def aten_renorm(self: TensorType, p: float, dim: int, maxnorm: float) -> TensorT
     raise NotImplementedError()


-def aten_repeat(self: TensorType, repeats: INT64) -> TensorType:
+def aten_repeat(self, repeats: INT64):
     # repeat(Tensor self, SymInt[] repeats) -> Tensor

-    raise NotImplementedError()
+    # FIXME(justinchuby): When repeats.shape == [0]
+
+    # TODO(justinchuby): Make ones_like a function when onnxscript supports it
+    # shape = ones_like(repeats) := {
+    one = op.Constant(value_int=1)
+    repeats_shape = op.Shape(repeats)
+    shape = op.Expand(one, repeats_shape)
+    # }
+    self_expanded = op.Expand(self, shape)  # type: ignore[arg-type]
+    return op.Tile(self_expanded, repeats)


 def aten_repeat_interleave(
@@ -4012,10 +4066,10 @@ def aten_rot90(self: TensorType, k: int = 1, dims: Sequence[int] = (0, 1)) -> Te
     raise NotImplementedError()


-def aten_round(self: TensorType) -> TensorType:
+def aten_round(self):
     # round(Tensor self) -> Tensor

-    raise NotImplementedError()
+    return op.Round(self)


 def aten_row_indices(self: TensorType) -> TensorType:
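The aten_ones_like docstring above expects dtype to be an ONNX TensorProto enum rather than a torch.dtype. A minimal sketch of the conversion a caller would perform, mirroring the TORCH_TYPE_TO_ONNX table added to the test file below (the helper name here is hypothetical and not part of this change):

import onnx
import torch

# Hypothetical helper, not part of this patch: map a torch dtype to the ONNX
# enum that aten_ones_like expects. -1 is the function's default and means
# "match the input's dtype" via CastLike.
_TORCH_TO_ONNX_DTYPE = {
    torch.float32: onnx.TensorProto.FLOAT,
    torch.float16: onnx.TensorProto.FLOAT16,
    torch.int64: onnx.TensorProto.INT64,
}


def onnx_dtype_for_ones_like(dtype=None):
    if dtype is None:
        return -1
    return _TORCH_TO_ONNX_DTYPE[dtype]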
diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py
index 9cb81eb229..a2a9dbac7c 100644
--- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py
+++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py
@@ -7,7 +7,9 @@
 from typing import Any, Callable, Collection, Iterable, Optional, Sequence, TypeVar

 import numpy as np
+import onnx
 import onnxruntime.capi.onnxruntime_pybind11_state
+import parameterized
 import torch
 from torch.testing._internal import common_device_type, common_methods_invocations
 from torch.testing._internal.opinfo import core as opinfo_core
@@ -69,14 +71,15 @@ class DecorateMeta:
     decorator: Callable[..., Any]
     dtypes: Optional[Collection[torch.dtype]]
     reason: str
+    matcher: Optional[Callable[[Any], bool]] = None


 def xfail(
     op_name: str,
     variant_name: str = "",
     *,
+    reason: str,
     dtypes: Optional[Collection[torch.dtype]] = None,
-    reason: Optional[str] = None,
 ):
     """Expects an OpInfo test to fail.

@@ -86,8 +89,6 @@ def xfail(
         dtypes: The dtypes to expect the failure.
         reason: The reason for the failure.
     """
-    if reason is None:
-        raise ValueError("Please specify a reason.")
     return DecorateMeta(
         op_name=op_name,
         variant_name=variant_name,
@@ -101,8 +102,9 @@ def skip(
     op_name: str,
     variant_name: str = "",
     *,
+    reason: str,
     dtypes: Optional[Collection[torch.dtype]] = None,
-    reason: Optional[str] = None,
+    matcher: Optional[Callable[[Any], Any]] = None,
 ):
     """Skips an OpInfo test.

@@ -111,15 +113,16 @@ def skip(
         variant_name: Optional OpInfo variant_test_name.
         dtypes: The dtypes to skip.
         reason: The reason for skipping.
+        matcher: A function that matches the test sample input. It is used only when
+            the op is in the SKIP_SUBTESTS list.
     """
-    if reason is None:
-        raise ValueError("Please specify a reason.")
     return DecorateMeta(
         op_name=op_name,
         variant_name=variant_name,
         decorator=unittest.skip(f"Don't care: {reason}"),
         dtypes=dtypes,
         reason=reason,
+        matcher=matcher,
     )
""" - if reason is None: - raise ValueError("Please specify a reason.") return DecorateMeta( op_name=op_name, variant_name=variant_name, decorator=unittest.skip(f"Don't care: {reason}"), dtypes=dtypes, reason=reason, + matcher=matcher, ) @@ -159,10 +162,18 @@ def wrapped(fn): # Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py OPINFO_FUNCTION_MAPPING: dict[str, Callable[..., Any]] = { "add": core_ops.aten_add, + # "clamp": core_ops.aten_clamp, # TODO(justinchuby): Enable + "clamp_max": core_ops.aten_clamp_max_tensor, + "clamp_min": core_ops.aten_clamp_min_tensor, + "gt": core_ops.aten_gt, + "lt": core_ops.aten_lt, "mul": core_ops.aten_mul, "nn.functional.elu": nn_ops.aten_elu, "nn.functional.relu6": nn_ops.aten_relu6, "nn.functional.selu": core_ops.aten_selu, + "ones_like": core_ops.aten_ones_like, + "repeat": core_ops.aten_repeat, + "round": core_ops.aten_round, "sub": core_ops.aten_sub, } @@ -170,6 +181,10 @@ def wrapped(fn): EXPECTED_SKIPS_OR_FAILS = ( xfail("add", dtypes=BOOL_TYPES, reason="Add is not defined on bool tensors"), + xfail("clamp_max", dtypes=BOOL_TYPES, reason="Min is not defined on bool tensors"), + xfail("clamp_min", dtypes=BOOL_TYPES, reason="Max is not defined on bool tensors"), + xfail("gt", dtypes=BOOL_TYPES, reason="Greater is not defined on bool tensors"), + xfail("lt", dtypes=BOOL_TYPES, reason="Less is not defined on bool tensors"), xfail("mul", dtypes=BOOL_TYPES, reason="Mul is not defined on bool tensors"), xfail( "nn.functional.elu", @@ -186,14 +201,117 @@ def wrapped(fn): dtypes=dtypes_except(torch.float16, torch.float32), reason="ONNX Runtime doesn't support float64 for Selu", ), + xfail( + "round", + variant_name="", + dtypes=dtypes_except(*FLOAT_TYPES), + reason="Round is not defined on non-float tensors", + ), + xfail("round", variant_name="decimals_0", reason="The ATen op does not support decimals"), + xfail("round", variant_name="decimals_3", reason="The ATen op does not support decimals"), + xfail( + "round", variant_name="decimals_neg_3", reason="The ATen op does not support decimals" + ), xfail("sub", dtypes=BOOL_TYPES, reason="Sub is not defined on bool tensors"), ) + + +SKIP_SUBTESTS = ( + skip( + "clamp_max", + reason="Empty tensor not yet supported", + matcher=lambda sample: sample.input.size() == torch.Size([0]), + ), + skip( + "clamp_min", + reason="Empty tensor not yet supported", + matcher=lambda sample: sample.input.size() == torch.Size([0]), + ), + skip( + "repeat", + reason="repeating when input is a scalar and repeats is empty is not supported", + matcher=lambda sample: sample.args[0] == (), + ), +) +OP_WITH_SKIPPED_SUBTESTS = frozenset(meta.op_name for meta in SKIP_SUBTESTS) + # END OF SECTION TO MODIFY ##################################################### OPS_DB = copy.deepcopy(common_methods_invocations.op_db) +TORCH_TYPE_TO_ONNX = { + torch.bool: onnx.TensorProto.BOOL, + torch.uint8: onnx.TensorProto.UINT8, + torch.int8: onnx.TensorProto.INT8, + torch.int16: onnx.TensorProto.INT16, + torch.int32: onnx.TensorProto.INT32, + torch.int64: onnx.TensorProto.INT64, + torch.float16: onnx.TensorProto.FLOAT16, + torch.float32: onnx.TensorProto.FLOAT, + torch.float64: onnx.TensorProto.DOUBLE, + torch.complex64: onnx.TensorProto.COMPLEX64, + torch.complex128: onnx.TensorProto.COMPLEX128, + torch.bfloat16: onnx.TensorProto.BFLOAT16, +} + + +class TestFunctionsCompilation(unittest.TestCase): + """Test all functions can be compiled.""" + + @parameterized.parameterized.expand( + 
@@ -236,10 +354,14 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
                 inputs=repr(inputs),
                 kwargs=repr(cpu_sample.kwargs),
             ):
-                input_numpy = [x.numpy() for x in inputs if isinstance(x, torch.Tensor)]
-                torch_output = op(*inputs, **cpu_sample.kwargs)
+                skip_reason = _should_skip_test_sample(op.name, cpu_sample)
+                if skip_reason is not None:
+                    self.skipTest(skip_reason)
+                input_onnx = [_convert_tensor_to_numpy(x) for x in inputs]
+                kwargs_onnx = _convert_kwargs_for_onnx(cpu_sample.kwargs)
+                output_torch = op(*inputs, **cpu_sample.kwargs)
                 try:
-                    function_output = scripted_function(*input_numpy, **cpu_sample.kwargs)
+                    function_output = scripted_function(*input_onnx, **kwargs_onnx)
                 # pylint: disable=c-extension-no-member
                 except onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented:
                     self.skipTest(
@@ -250,7 +372,7 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):

                 # Use torch testing to ensure dtypes and shapes match
                 torch.testing.assert_close(
                     torch.tensor(function_output),
-                    torch_output,
+                    output_torch,
                 )
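For reference, the new functions can be exercised eagerly the same way test_output_match drives them: compile with onnxscript.script() and call the result with numpy inputs. A small sketch (not part of the patch; assumes onnxscript, onnxruntime, and numpy are installed):

import numpy as np
import onnxscript
from onnxscript.function_libs.torch_aten.ops import core as core_ops

# Compile the function, then evaluate it eagerly on numpy inputs,
# mirroring how scripted_function is invoked in test_output_match.
scripted = onnxscript.script()(core_ops.aten_round)
print(scripted(np.array([1.4, 2.5, 2.6], dtype=np.float32)))  # approximately [1., 2., 3.]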