From c7598e2dd683948552e16d7b78d2cfdbb75b2d13 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Wed, 22 Feb 2023 17:05:05 +0800 Subject: [PATCH 1/2] add ops --- onnxscript/function_libs/torch_aten/ops/core.py | 10 ++++++---- .../function_libs/torch_aten/ops_correctness_test.py | 2 ++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index 0a77c229b0..06bdaebe98 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -4641,16 +4641,18 @@ def aten_reshape_as(self: TensorType, other: TensorType) -> TensorType: raise NotImplementedError() -def aten_resolve_conj(self: TensorType) -> TensorType: +@torch_op("aten::resolve_conj") +def aten_resolve_conj(self: TTensor) -> TensorType: """resolve_conj(Tensor(a) self) -> Tensor(a)""" - raise NotImplementedError() + return op.Identity(self) -def aten_resolve_neg(self: TensorType) -> TensorType: +@torch_op("aten::resolve_neg") +def aten_resolve_neg(self: TTensor) -> TensorType: """resolve_neg(Tensor(a) self) -> Tensor(a)""" - raise NotImplementedError() + return op.Identity(self) def aten_result_type(tensor: TensorType, other: TensorType) -> int: diff --git a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py index 7bbc401ab8..994d7dc482 100644 --- a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py @@ -334,6 +334,8 @@ def _where_input_wrangler( "remainder": core_ops.aten_remainder, "repeat": core_ops.aten_repeat, "reshape": core_ops.aten_reshape, + "resolve_conj": core_ops.aten_resolve_conj, + "resolve_neg": core_ops.aten_resolve_neg, "round": core_ops.aten_round, "rsqrt": core_ops.aten_rsqrt, "rsub": core_ops.aten_rsub, From 4bfb954dbcbfb6bff6c5419b8af74e337793b09d Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Thu, 23 Feb 2023 12:23:30 +0800 Subject: [PATCH 2/2] Update core.py --- onnxscript/function_libs/torch_aten/ops/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index 06bdaebe98..f502a35858 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -4642,14 +4642,14 @@ def aten_reshape_as(self: TensorType, other: TensorType) -> TensorType: @torch_op("aten::resolve_conj") -def aten_resolve_conj(self: TTensor) -> TensorType: +def aten_resolve_conj(self: TTensor) -> TTensor: """resolve_conj(Tensor(a) self) -> Tensor(a)""" return op.Identity(self) @torch_op("aten::resolve_neg") -def aten_resolve_neg(self: TTensor) -> TensorType: +def aten_resolve_neg(self: TTensor) -> TTensor: """resolve_neg(Tensor(a) self) -> Tensor(a)""" return op.Identity(self)