diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index 90b1ae453e..e9e45f7f15 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -803,10 +803,13 @@ def aten_clamp_min(self: TReal, min_: TReal) -> TReal:
     return result
 
 
-def aten_clone(self: TensorType, memory_format: Optional[str] = None) -> TensorType:
+@torch_op("aten::clone")
+def aten_clone(
+    self: TTensor, memory_format: str = ""  # pylint: disable=unused-argument
+) -> TTensor:
     # clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor
 
-    raise NotImplementedError()
+    return op.Identity(self)
 
 
 def aten_coalesce(self: TensorType) -> TensorType:
@@ -1406,10 +1409,11 @@ def aten_dist(self: TensorType, other: TensorType, p: float = 2) -> TensorType:
     raise NotImplementedError()
 
 
-def aten_div(self: TensorType, other: TensorType) -> TensorType:
+@torch_op("aten::div")
+def aten_div(self: TReal, other: TReal) -> TReal:
     # div.Tensor(Tensor self, Tensor other) -> Tensor
 
-    raise NotImplementedError()
+    return op.Div(self, other)
 
 
 def aten_divide(self: TensorType, other: TensorType) -> TensorType:
@@ -1529,16 +1533,21 @@ def aten_empty_strided(size: INT64, stride: INT64) -> TensorType:
     raise NotImplementedError()
 
 
-def aten_eq(self: TensorType, other: TensorType) -> TensorType:
+@torch_op("aten::eq")
+def aten_eq(self: TTensor, other: TTensor) -> BOOL:
     # eq.Tensor(Tensor self, Tensor other) -> Tensor
 
-    raise NotImplementedError()
+    return op.Equal(self, other)
 
 
-def aten_equal(self: TensorType, other: TensorType) -> bool:
+@torch_op("aten::equal")
+def aten_equal(self: TTensor, other: TTensor) -> BOOL:
     # equal(Tensor self, Tensor other) -> bool
 
-    raise NotImplementedError()
+    sub_self_other = op.Sub(self, other)
+    abs_sub = op.Abs(sub_self_other)
+    sum_of_abs = op.ReduceSum(abs_sub, keepdims=0)
+    return op.Equal(sum_of_abs, 0)
 
 
 @torch_op("aten::erf")
@@ -1576,10 +1585,12 @@ def aten_exp2(self: TFloat) -> TFloat:
     return op.Pow(two, self)
 
 
-def aten_expand(self: TensorType, size: INT64, implicit: bool = False) -> TensorType:
+@torch_op("aten::expand")
+def aten_expand(self: TTensor, size: INT64) -> TTensor:
     # expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)
 
-    raise NotImplementedError()
+    size = op.Cast(size, to=INT64.dtype)  # to INT64
+    return op.Expand(self, size)
 
 
 def aten_expand_as(self: TensorType, other: TensorType) -> TensorType:
@@ -4046,10 +4057,12 @@ def aten_repeat_interleave(
     raise NotImplementedError()
 
 
-def aten_reshape(self: TensorType, shape: INT64) -> TensorType:
+@torch_op("aten::reshape")
+def aten_reshape(self: TTensor, shape: INT64) -> TTensor:
     # reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)
 
-    raise NotImplementedError()
+    shape = op.Cast(shape, to=INT64.dtype)  # Reshape only support INT64 as 'shape'
+    return op.Reshape(self, shape)
 
 
 def aten_reshape_as(self: TensorType, other: TensorType) -> TensorType:
@@ -4484,7 +4497,9 @@ def aten_subtract(self: TensorType, other: TensorType, alpha: float = 1) -> Tens
     raise NotImplementedError()
 
 
-def aten_sum(self: TensorType, dtype: Optional[int] = None) -> TensorType:
+def aten_sum(
+    self: TensorType, dim: Optional[int] = None, keepdim: bool = False, dtype: int = -1
+) -> TensorType:
     # sum(Tensor self, *, ScalarType? dtype=None) -> Tensor
 
     raise NotImplementedError()
@@ -4903,10 +4918,12 @@ def aten_vdot(self: TensorType, other: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
-def aten_view(self: TensorType, size: INT64) -> TensorType:
+@torch_op("aten::view")
+def aten_view(self: TTensor, size: INT64) -> TTensor:
     # view(Tensor(a) self, SymInt[] size) -> Tensor(a)
 
-    raise NotImplementedError()
+    size = op.Cast(size, to=INT64.dtype)  # Reshape only support INT64 as second input
+    return op.Reshape(self, size)
 
 
 def aten_view_as(self: TensorType, other: TensorType) -> TensorType:
diff --git a/onnxscript/function_libs/torch_aten/ops/special.py b/onnxscript/function_libs/torch_aten/ops/special.py
index 74c3001eea..0b88e81506 100644
--- a/onnxscript/function_libs/torch_aten/ops/special.py
+++ b/onnxscript/function_libs/torch_aten/ops/special.py
@@ -202,9 +202,7 @@ def aten_special_log_ndtr(self: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
-def aten_special_log_softmax(
-    self: TensorType, dim: int, dtype: Optional[int] = None
-) -> TensorType:
+def aten_special_log_softmax(self: TensorType, dim: int, dtype: int = -1) -> TensorType:
     # special_log_softmax(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor
 
     raise NotImplementedError()
diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py
index d5d167ab83..ad9d629e8b 100644
--- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py
+++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py
@@ -204,12 +204,17 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
     "clamp_max": core_ops.aten_clamp_max,
     "clamp_min": core_ops.aten_clamp_min,
     "clamp": core_ops.aten_clamp,
+    "clone": core_ops.aten_clone,
     "cos": core_ops.aten_cos,
     "cosh": core_ops.aten_cosh,
+    "div": core_ops.aten_div,
     "dot": core_ops.aten_dot,
-    "erf": core_ops.aten_erf,
+    "eq": core_ops.aten_eq,
+    "equal": core_ops.aten_equal,
     "exp": core_ops.aten_exp,
     "exp2": core_ops.aten_exp2,
+    "expand": core_ops.aten_expand,
+    "erf": core_ops.aten_erf,
     "fmod": core_ops.aten_fmod,
     # TODO(justinchuby): Test aten::full
     "full_like": core_ops.aten_full_like,
@@ -242,6 +247,7 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
     "reciprocal": core_ops.aten_reciprocal,
     "remainder": core_ops.aten_remainder,
     "repeat": core_ops.aten_repeat,
+    "reshape": core_ops.aten_reshape,
     "round": core_ops.aten_round,
     "rsqrt": core_ops.aten_rsqrt,
     "rsub": core_ops.aten_rsub,
@@ -256,6 +262,7 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
     "tanh": core_ops.aten_tanh,
     "transpose": core_ops.aten_transpose,
     "unsqueeze": core_ops.aten_unsqueeze,
+    "view": core_ops.aten_view,
     "where": core_ops.aten_where,
     "zeros": core_ops.aten_zeros,
     "zeros_like": core_ops.aten_zeros_like,
@@ -282,6 +289,16 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
 
 
 SKIP_SUBTESTS: tuple[DecorateMeta, ...] = (
+    skip(
+        "div",
+        matcher=lambda sample: sample.kwargs.get("rounding_mode") is not None,
+        reason="rounding_mode is not yet supported",
+    ),
+    skip(
+        "expand",
+        matcher=lambda sample: (np.array(sample.args[0]) > 0).all() is np.bool_(False),
+        reason="Negative value is not supported",
+    ),
     skip(
         "nonzero",
         matcher=lambda sample: sample.kwargs.get("as_tuple") is not None,
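
Editorial note, not part of the patch: the snippet below is a minimal sketch of how the newly implemented ops could be sanity-checked by hand. It assumes the eager-evaluation behavior that ops_correctness_test.py relies on, namely that an @torch_op-decorated function can be called directly with numpy inputs and returns numpy outputs; treat the exact calling convention as an assumption rather than a documented API.

import numpy as np

from onnxscript.function_libs.torch_aten.ops import core as core_ops

x = np.arange(6, dtype=np.float32).reshape(2, 3)

# aten_clone lowers to Identity, so the output should equal the input.
np.testing.assert_allclose(core_ops.aten_clone(x), x)

# aten_view and aten_reshape both lower to Reshape with an INT64 shape input.
y = core_ops.aten_view(x, np.array([3, 2], dtype=np.int64))
assert y.shape == (3, 2)

# aten_expand lowers to Expand; the new SKIP_SUBTESTS entry documents that
# negative (inferred) sizes are not handled, so only positive sizes are used here.
z = core_ops.aten_expand(
    np.array([[1.0], [2.0]], dtype=np.float32), np.array([2, 4], dtype=np.int64)
)
assert z.shape == (2, 4)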