Skip to content

Commit b5ca6f7

Browse files
authored
feat(atenlib): add ops(resolve_conj, resolve_neg) (#458)
1 parent 6c5f9a8 commit b5ca6f7

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4646,16 +4646,18 @@ def aten_reshape_as(self: TensorType, other: TensorType) -> TensorType:
46464646
raise NotImplementedError()
46474647

46484648

4649-
def aten_resolve_conj(self: TensorType) -> TensorType:
4649+
@torch_op("aten::resolve_conj")
4650+
def aten_resolve_conj(self: TTensor) -> TTensor:
46504651
"""resolve_conj(Tensor(a) self) -> Tensor(a)"""
46514652

4652-
raise NotImplementedError()
4653+
return op.Identity(self)
46534654

46544655

4655-
def aten_resolve_neg(self: TensorType) -> TensorType:
4656+
@torch_op("aten::resolve_neg")
4657+
def aten_resolve_neg(self: TTensor) -> TTensor:
46564658
"""resolve_neg(Tensor(a) self) -> Tensor(a)"""
46574659

4658-
raise NotImplementedError()
4660+
return op.Identity(self)
46594661

46604662

46614663
def aten_result_type(tensor: TensorType, other: TensorType) -> int:

onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,8 @@ def _where_input_wrangler(
330330
"remainder": core_ops.aten_remainder,
331331
"repeat": core_ops.aten_repeat,
332332
"reshape": core_ops.aten_reshape,
333+
"resolve_conj": core_ops.aten_resolve_conj,
334+
"resolve_neg": core_ops.aten_resolve_neg,
333335
"round": core_ops.aten_round,
334336
"rsqrt": core_ops.aten_rsqrt,
335337
"rsub": core_ops.aten_rsub,

0 commit comments

Comments
 (0)