diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index c789ccd942..5201fb28a2 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -332,22 +332,22 @@ def aten_as_strided_scatter(
     raise NotImplementedError()
 
 
-def aten_asin(self: TensorType) -> TensorType:
+def aten_asin(self):
     # asin(Tensor self) -> Tensor
 
-    raise NotImplementedError()
+    return op.Asin(self)
 
 
-def aten_asinh(self: TensorType) -> TensorType:
+def aten_asinh(self):
     # asinh(Tensor self) -> Tensor
 
-    raise NotImplementedError()
+    return op.Asinh(self)
 
 
-def aten_atan(self: TensorType) -> TensorType:
+def aten_atan(self):
     # atan(Tensor self) -> Tensor
 
-    raise NotImplementedError()
+    return op.Atan(self)
 
 
 def aten_atan2(self: TensorType, other: TensorType) -> TensorType:
@@ -356,10 +356,10 @@ def aten_atan2(self: TensorType, other: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
-def aten_atanh(self: TensorType) -> TensorType:
+def aten_atanh(self):
     # atanh(Tensor self) -> Tensor
 
-    raise NotImplementedError()
+    return op.Atanh(self)
 
 
 def aten_atleast_1d(self: TensorType) -> TensorType:
@@ -670,16 +670,10 @@ def aten_cdist(
     raise NotImplementedError()
 
 
-def aten_ceil(self: TensorType) -> TensorType:
+def aten_ceil(self):
     # ceil(Tensor self) -> Tensor
 
-    raise NotImplementedError()
-
-
-def aten_celu(self: TensorType, alpha: float = 1.0) -> TensorType:
-    # celu(Tensor self, Scalar alpha=1.0) -> Tensor
-
-    raise NotImplementedError()
+    return op.Ceil(self)
 
 
 def aten_chain_matmul(matrices: Sequence[TensorType]) -> TensorType:
@@ -785,14 +779,6 @@ def aten_clamp_min_tensor(self, min_):
     return op.Max(self, min_)
 
 
-def aten_clip(
-    self: TensorType, min: Optional[float] = None, max: Optional[float] = None
-) -> TensorType:
-    # clip(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor
-
-    raise NotImplementedError()
-
-
 def aten_clone(self: TensorType, memory_format: Optional[str] = None) -> TensorType:
     # clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor
 
@@ -1031,16 +1017,16 @@ def aten_corrcoef(self: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
-def aten_cos(self: TensorType) -> TensorType:
+def aten_cos(self):
     # cos(Tensor self) -> Tensor
 
-    raise NotImplementedError()
+    return op.Cos(self)
 
 
-def aten_cosh(self: TensorType) -> TensorType:
+def aten_cosh(self):
     # cosh(Tensor self) -> Tensor
 
-    raise NotImplementedError()
+    return op.Cosh(self)
 
 
 def aten_cosine_embedding_loss(
@@ -1406,10 +1392,10 @@ def aten_divide(self: TensorType, other: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
-def aten_dot(self: TensorType, tensor: TensorType) -> TensorType:
+def aten_dot(self, tensor):
     # dot(Tensor self, Tensor tensor) -> Tensor
 
-    raise NotImplementedError()
+    return op.MatMul(self, tensor)
 
 
 def aten_dropout(input: TensorType, p: float, train: bool) -> TensorType:
@@ -1546,16 +1532,18 @@ def aten_erfinv(self: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
-def aten_exp(self: TensorType) -> TensorType:
+def aten_exp(self):
     # exp(Tensor self) -> Tensor
 
-    raise NotImplementedError()
+    return op.Exp(self)
 
 
-def aten_exp2(self: TensorType) -> TensorType:
+def aten_exp2(self):
     # exp2(Tensor self) -> Tensor
 
-    raise NotImplementedError()
+    two = op.Constant(value_int=2)
+    two = op.CastLike(two, self)  # type: ignore[arg-type]
+    return op.Pow(two, self)  # type: ignore[arg-type]
 
 
 def aten_expand(self: TensorType, size: INT64, implicit: bool = False) -> TensorType:
@@ -4205,22 +4193,16 @@ def aten_signbit(self: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
-def aten_sin(self: TensorType) -> TensorType:
+def aten_sin(self):
     # sin(Tensor self) -> Tensor
 
-    raise NotImplementedError()
-
+    return op.Sin(self)
 
-def aten_sinc(self: TensorType) -> TensorType:
-    # sinc(Tensor self) -> Tensor
 
-    raise NotImplementedError()
-
-
-def aten_sinh(self: TensorType) -> TensorType:
+def aten_sinh(self):
     # sinh(Tensor self) -> Tensor
 
-    raise NotImplementedError()
+    return op.Sinh(self)
 
 
 def aten_slice(
@@ -4483,16 +4465,16 @@ def aten_take_along_dim(
     raise NotImplementedError()
 
 
-def aten_tan(self: TensorType) -> TensorType:
+def aten_tan(self):
     # tan(Tensor self) -> Tensor
 
-    raise NotImplementedError()
+    return op.Tan(self)
 
 
-def aten_tanh(self: TensorType) -> TensorType:
+def aten_tanh(self):
     # tanh(Tensor self) -> Tensor
 
-    raise NotImplementedError()
+    return op.Tanh(self)
 
 
 def aten_tensordot(
diff --git a/onnxscript/function_libs/torch_aten/ops/nn.py b/onnxscript/function_libs/torch_aten/ops/nn.py
index d76250ea8a..0a1968d097 100644
--- a/onnxscript/function_libs/torch_aten/ops/nn.py
+++ b/onnxscript/function_libs/torch_aten/ops/nn.py
@@ -150,6 +150,12 @@ def aten_binary_cross_entropy_backward(
     raise NotImplementedError()
 
 
+def aten_celu(self, alpha: float = 1.0):
+    # celu(Tensor self, Scalar alpha=1.0) -> Tensor
+
+    raise NotImplementedError()
+
+
 def aten_col2im(
     self: TensorType,
     output_size: INT64,
diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py
index c813c2eb76..aaee264208 100644
--- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py
+++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py
@@ -119,7 +119,7 @@ def skip(
     return DecorateMeta(
         op_name=op_name,
         variant_name=variant_name,
-        decorator=unittest.skip(f"Don't care: {reason}"),
+        decorator=unittest.skip(f"Skip: {reason}"),
         dtypes=dtypes,
         reason=reason,
         matcher=matcher,
@@ -166,10 +166,20 @@ def wrapped(fn):
     "acosh": core_ops.aten_acosh,
     "add": core_ops.aten_add,
     "addmm": core_ops.aten_addmm,
+    "asin": core_ops.aten_asin,
+    "asinh": core_ops.aten_asinh,
+    "atan": core_ops.aten_atan,
+    "atanh": core_ops.aten_atanh,
     "bmm": core_ops.aten_bmm,
+    "ceil": core_ops.aten_ceil,
     "clamp_max": core_ops.aten_clamp_max_tensor,
     "clamp_min": core_ops.aten_clamp_min_tensor,
     "clamp": core_ops.aten_clamp,
+    "cos": core_ops.aten_cos,
+    "cosh": core_ops.aten_cosh,
+    "dot": core_ops.aten_dot,
+    "exp": core_ops.aten_exp,
+    "exp2": core_ops.aten_exp2,
     "gt": core_ops.aten_gt,
     "lt": core_ops.aten_lt,
     "matmul": core_ops.aten_matmul,
@@ -183,8 +193,12 @@ def wrapped(fn):
     "ones": core_ops.aten_ones,
     "repeat": core_ops.aten_repeat,
     "round": core_ops.aten_round,
+    "sin": core_ops.aten_sin,
+    "sinh": core_ops.aten_sinh,
     "sub": core_ops.aten_sub,
     "t": core_ops.aten_t,
+    "tan": core_ops.aten_tan,
+    "tanh": core_ops.aten_tanh,
     # "transpose": core_ops.aten_transpose,  # TODO(justinchuby): Enable when onnxscript errors are fixed
 }
 
@@ -206,6 +220,7 @@ def wrapped(fn):
         "addmm",
         dtypes=[torch.uint8, torch.int8, torch.int16],
         reason="MatMul is not defined on int16/int8/uint8 tensors",
+        # TODO(justinchuby): Use MatMulInteger
     ),
     xfail(
         "addmm",
@@ -213,14 +228,64 @@ def wrapped(fn):
         dtypes=[torch.uint8, torch.int8, torch.int16],
         reason="MatMul is not defined on int16/int8/uint8 tensors",
     ),
+    xfail(
+        "asin",
+        dtypes=BOOL_TYPES + INT_TYPES,
+        reason="Asin is not defined on bool or int tensors",
+    ),
+    xfail(
+        "asinh",
+        dtypes=BOOL_TYPES + INT_TYPES,
+        reason="Asinh is not defined on bool or int tensors",
+    ),
+    xfail(
+        "atan",
+        dtypes=BOOL_TYPES + INT_TYPES,
+        reason="Atan is not defined on bool or int tensors",
+    ),
+    xfail(
+        "atanh",
+        dtypes=BOOL_TYPES + INT_TYPES,
+        reason="Atanh is not defined on bool or int tensors",
+    ),
     xfail(
         "bmm",
         dtypes=[torch.uint8, torch.int8, torch.int16],
         reason="MatMul is not defined on int16/int8/uint8 tensors",
     ),
+    xfail(
+        "ceil",
+        dtypes=BOOL_TYPES + INT_TYPES,
+        reason="Ceil is not defined on bool or int tensors",
+    ),
     skip("clamp", reason="Enable when onnxscript errors are fixed"),
     xfail("clamp_max", dtypes=BOOL_TYPES, reason="Min is not defined on bool tensors"),
     xfail("clamp_min", dtypes=BOOL_TYPES, reason="Max is not defined on bool tensors"),
+    xfail(
+        "cos",
+        dtypes=BOOL_TYPES + INT_TYPES,
+        reason="Cos is not defined on bool or int tensors",
+    ),
+    xfail(
+        "cosh",
+        dtypes=BOOL_TYPES + INT_TYPES,
+        reason="Cosh is not defined on bool or int tensors",
+    ),
+    xfail(
+        "dot",
+        dtypes=[torch.uint8, torch.int8, torch.int16],
+        reason="MatMul is not defined on int16/int8/uint8 tensors",
+    ),
+    xfail(
+        "exp",
+        dtypes=BOOL_TYPES + INT_TYPES,
+        reason="Exp is not defined on bool or int tensors",
+    ),
+    xfail(
+        "exp2",
+        dtypes=BOOL_TYPES + INT_TYPES,
+        reason="Pow is not defined on bool or int tensors",
+    ),
     xfail("gt", dtypes=BOOL_TYPES, reason="Greater is not defined on bool tensors"),
     xfail("lt", dtypes=BOOL_TYPES, reason="Less is not defined on bool tensors"),
     xfail(
@@ -264,7 +329,27 @@ def wrapped(fn):
     xfail(
         "round", variant_name="decimals_neg_3", reason="The ATen op does not support decimals"
     ),
+    xfail(
+        "sin",
+        dtypes=BOOL_TYPES + INT_TYPES,
+        reason="Sin is not defined on bool or int tensors",
+    ),
+    xfail(
+        "sinh",
+        dtypes=BOOL_TYPES + INT_TYPES,
+        reason="Sinh is not defined on bool or int tensors",
+    ),
     xfail("sub", dtypes=BOOL_TYPES, reason="Sub is not defined on bool tensors"),
+    xfail(
+        "tan",
+        dtypes=BOOL_TYPES + INT_TYPES,
+        reason="Tan is not defined on bool or int tensors",
+    ),
+    xfail(
+        "tanh",
+        dtypes=BOOL_TYPES + INT_TYPES,
+        reason="Tanh is not defined on bool or int tensors",
+    ),
     xfail("transpose", reason="Enable when onnxscript errors are fixed"),
 )
 
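Note (not part of the patch): most of the new functions map one-to-one onto an ONNX operator, but aten_exp2 is decomposed as Pow(CastLike(Constant(2), self), self). A quick way to sanity-check that decomposition numerically is to mirror it in plain PyTorch and compare against torch.exp2. The snippet below is only an illustrative sketch; the helper name exp2_by_pow is hypothetical and nothing here is part of the change itself.

import torch

def exp2_by_pow(x: torch.Tensor) -> torch.Tensor:
    # Mirror the ONNX decomposition: cast the constant 2 to the input's dtype,
    # then compute Pow(2, x) element-wise.
    two = torch.tensor(2, dtype=x.dtype, device=x.device)
    return torch.pow(two, x)

x = torch.randn(8)
assert torch.allclose(exp2_by_pow(x), torch.exp2(x))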