
Commit 000d657

feat(atenlib): add ops (ArgMax, ArgMin, Detach) (#324)
Add the argmax and argmin ops, but in trace_only=True mode, because OptionalHasElement(dim) cannot be used yet.
1 parent be4ea50 commit 000d657
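
For context, a minimal NumPy sketch (not part of this commit; the helper name is illustrative) of the behaviour the new aten_argmax/aten_argmin target: a scalar input or dim=None is flattened before the reduction, and for a scalar input the result is squeezed back to a scalar index.

import numpy as np

def argmax_reference(x, dim=None, keepdim=False):
    # Illustrative reference only; mirrors the Reshape/ArgMax/Squeeze strategy in the diff below.
    is_scalar = x.ndim == 0
    if is_scalar or dim is None:
        x = x.reshape(-1)  # ONNX ArgMax needs at least one axis to reduce over
        dim = 0
    result = np.argmax(x, axis=dim, keepdims=keepdim)
    return np.squeeze(result) if is_scalar else result

For example, argmax_reference(np.array(3.5)) returns 0, matching torch.argmax on a 0-d tensor.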

File tree

2 files changed: +31 −9 lines


onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 28 additions & 9 deletions
@@ -331,20 +331,38 @@ def aten_arctanh(self: TensorType) -> TensorType:
     raise NotImplementedError()


-def aten_argmax(
-    self: TensorType, dim: Optional[int] = None, keepdim: bool = False
-) -> TensorType:
+@torch_op("aten::argmax", trace_only=True)
+def aten_argmax(self: TReal, dim: Optional[int] = None, keepdim: bool = False) -> TReal:
     # argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor

-    raise NotImplementedError()
+    self_is_scaler = op.Size(op.Shape(self)) == 0
+    if self_is_scaler:
+        self = op.Reshape(self, op.Constant(value_ints=[-1]))
+    elif dim is None:  # should use OptionalHasElement(dim)
+        self = op.Reshape(self, op.Constant(value_ints=[-1]))

+    result = op.ArgMax(self, axis=dim, keepdims=keepdim)
+    if self_is_scaler:
+        result = op.Squeeze(result)

-def aten_argmin(
-    self: TensorType, dim: Optional[int] = None, keepdim: bool = False
-) -> TensorType:
+    return result
+
+
+@torch_op("aten::argmin", trace_only=True)
+def aten_argmin(self: TReal, dim: Optional[int] = None, keepdim: bool = False) -> TReal:
     # argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor

-    raise NotImplementedError()
+    self_is_scaler = op.Size(op.Shape(self)) == 0
+    if self_is_scaler:
+        self = op.Reshape(self, op.Constant(value_ints=[-1]))
+    elif dim is None:  # should use OptionalHasElement(dim)
+        self = op.Reshape(self, op.Constant(value_ints=[-1]))
+
+    result = op.ArgMin(self, axis=dim, keepdims=keepdim)
+    if self_is_scaler:
+        result = op.Squeeze(result)
+
+    return result


 def aten_argsort(self: TensorType, dim: int = -1, descending: bool = False) -> TensorType:
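
Because these functions are trace_only, the Python if/elif branches are resolved while the graph is being built rather than exported as graph control flow; for a non-scalar input with an explicit dim, the trace reduces to a single ONNX ArgMax node. A hedged sketch (not part of the commit) checking that node's semantics against NumPy:

import numpy as np
import onnxruntime as ort
from onnx import TensorProto, helper

# Single-node model equivalent to the dim=1, keepdim=False trace.
node = helper.make_node("ArgMax", ["x"], ["y"], axis=1, keepdims=0)
graph = helper.make_graph(
    [node],
    "argmax_check",
    [helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 3])],
    [helper.make_tensor_value_info("y", TensorProto.INT64, [2])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])

x = np.random.rand(2, 3).astype(np.float32)
sess = ort.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
(y,) = sess.run(None, {"x": x})
assert np.array_equal(y, np.argmax(x, axis=1))
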
@@ -1383,10 +1401,11 @@ def aten_det(self: TensorType) -> TensorType:
     raise NotImplementedError()


+@torch_op("aten::detach")
 def aten_detach(self: TensorType) -> TensorType:
     # detach(Tensor(a) self) -> Tensor(a)

-    raise NotImplementedError()
+    return op.Identity(self)


 def aten_detach_copy(self: TensorType) -> TensorType:
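
Mapping detach to Identity relies on detach only severing autograd tracking while leaving the values untouched, which is all that matters in an inference-only ONNX graph. A small illustrative check (not from the commit):

import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
y = x.detach()
assert torch.equal(x, y)    # same values, so Identity preserves semantics
assert not y.requires_grad  # only the autograd link is dropped
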

onnxscript/test/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 3 additions & 0 deletions
@@ -273,6 +273,7 @@ def _topk_input_wrangler(
     "clone": core_ops.aten_clone,
     "cos": core_ops.aten_cos,
     "cosh": core_ops.aten_cosh,
+    # "detach": core_ops.aten_detach,  # detach is not in OP-TEST-DB
     "div": core_ops.aten_div,
     "dot": core_ops.aten_dot,
     "empty": core_ops.aten_empty,
@@ -364,6 +365,8 @@ def _topk_input_wrangler(
     str,
     Callable[..., Any] | tuple[Callable[..., Any], Callable[..., Any]],
 ] = {
+    "argmax": core_ops.aten_argmax,
+    "argmin": core_ops.aten_argmin,
     "cat": core_ops.aten_cat,
     "index_select": core_ops.aten_index_select,
     "transpose": core_ops.aten_transpose,
