Skip to content

Commit 06a0a5c

Browse files
Split arg{max,min} into arg{max,min}_dim and improve function signatures | feat(atenlib) (#677)
To remove trace_only and allow the dispatcher to select the correct function, we explicitly split argmax with argmax_dim and fixed the function return type. --- One potential problem as below code: ```python @torch_op("aten::argmax", overload=True) def aten_argmax_dim(self: TReal, dim: int, keepdim: bool = False) -> TInt: self_is_scaler = op.Size(op.Shape(self)) == 0 if self_is_scaler: self = op.Reshape(self, op.Constant(value_ints=[-1])) result = op.ArgMax(self, axis=dim, keepdims=keepdim) if self_is_scaler: result = op.Squeeze(result) return result ``` The above code works well. But when it was rewritten to below (reduce one if/else), the ShapeInference will fail: ```python @torch_op("aten::argmax", overload=True) def aten_argmax_dim(self: TReal, dim: int, keepdim: bool = False) -> TInt: self_is_scaler = op.Size(op.Shape(self)) == 0 if self_is_scaler: self = op.Reshape(self, op.Constant(value_ints=[-1])) result = op.ArgMax(self, axis=dim, keepdims=keepdim) result = op.Squeeze(result) else: result = op.ArgMax(self, axis=dim, keepdims=keepdim) return result ``` --------- Co-authored-by: Justin Chu <justinchu@microsoft.com>
1 parent 9c69053 commit 06a0a5c

2 files changed

Lines changed: 59 additions & 21 deletions

File tree

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import math
1515
from typing import Any, Optional, Sequence, Tuple, Union
1616

17-
from onnxscript import BOOL, DOUBLE, FLOAT, INT8, INT16, INT32, INT64, graph
17+
from onnxscript import BOOL, DOUBLE, FLOAT, INT8, INT16, INT32, INT64, UINT8, graph
1818
from onnxscript.function_libs.torch_lib.registration import torch_op
1919
from onnxscript.function_libs.torch_lib.tensor_typing import (
2020
IntType,
@@ -520,20 +520,21 @@ def aten_arctanh(self: TensorType) -> TensorType:
520520
raise NotImplementedError()
521521

522522

523-
@torch_op("aten::argmax", trace_only=True)
524-
def aten_argmax(
525-
self: TRealOrUInt8, dim: Optional[int] = None, keepdim: bool = False
526-
) -> TRealOrUInt8:
523+
@torch_op("aten::argmax")
524+
def aten_argmax(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64:
527525
"""argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"""
528526

529-
if dim is None: # TODO: use OptionalHasElement(dim)
530-
self = op.Reshape(self, op.Constant(value_ints=[-1]))
527+
self_is_scaler = op.Size(op.Shape(self)) == 0
528+
self = op.Reshape(self, op.Constant(value_ints=[-1]))
529+
result = op.ArgMax(self, keepdims=keepdim)
530+
if self_is_scaler:
531+
result = op.Squeeze(result)
531532

532-
return _aten_argmax_dim(self, dim=dim, keepdim=keepdim)
533+
return result
533534

534535

535-
@torch_op("aten::argmax", private=True)
536-
def _aten_argmax_dim(self: TRealOrUInt8, dim: int, keepdim: bool = False) -> TRealOrUInt8:
536+
@torch_op("aten::argmax")
537+
def aten_argmax_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = False) -> INT64:
537538
"""argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"""
538539

539540
self_is_scaler = op.Size(op.Shape(self)) == 0
@@ -547,20 +548,21 @@ def _aten_argmax_dim(self: TRealOrUInt8, dim: int, keepdim: bool = False) -> TRe
547548
return result
548549

549550

550-
@torch_op("aten::argmin", trace_only=True)
551-
def aten_argmin(
552-
self: TRealOrUInt8, dim: Optional[int] = None, keepdim: bool = False
553-
) -> TRealOrUInt8:
551+
@torch_op("aten::argmin")
552+
def aten_argmin(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64:
554553
"""argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"""
555554

556-
if dim is None: # TODO: use OptionalHasElement(dim)
557-
self = op.Reshape(self, op.Constant(value_ints=[-1]))
555+
self_is_scaler = op.Size(op.Shape(self)) == 0
556+
self = op.Reshape(self, op.Constant(value_ints=[-1]))
557+
result = op.ArgMin(self, keepdims=keepdim)
558+
if self_is_scaler:
559+
result = op.Squeeze(result)
558560

559-
return _aten_argmin_dim(self, dim=dim, keepdim=keepdim)
561+
return result
560562

561563

562-
@torch_op("aten::argmin", private=True)
563-
def _aten_argmin_dim(self: TRealOrUInt8, dim: int, keepdim: bool = False) -> TRealOrUInt8:
564+
@torch_op("aten::argmin")
565+
def aten_argmin_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = False) -> INT64:
564566
"""argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"""
565567

566568
self_is_scaler = op.Size(op.Shape(self)) == 0

onnxscript/tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1209,7 +1209,24 @@ def _where_input_wrangler(
12091209
matcher=lambda sample: sample.kwargs.get("end") is not None,
12101210
reason="arange overload does not support positional 'end' argument",
12111211
),
1212-
TorchLibOpInfo("argmax", core_ops.aten_argmax, trace_only=True)
1212+
TorchLibOpInfo("argmax", core_ops.aten_argmax)
1213+
.skip(
1214+
matcher=lambda sample: "dim" in sample.kwargs,
1215+
reason="this overload does not support the 'dim' attribute by design",
1216+
)
1217+
.skip(
1218+
enabled_if=ops_test_common.IS_WINDOWS,
1219+
reason="fixme: ORT has memory errors. https://github.com/microsoft/onnxruntime/issues/16492",
1220+
)
1221+
.xfail(
1222+
dtypes=(torch.int64,),
1223+
reason="fixme: ORT did not implement ArgMax for int64. https://github.com/microsoft/onnxruntime/issues/16654",
1224+
),
1225+
TorchLibOpInfo("argmax_dim", core_ops.aten_argmax_dim)
1226+
.xfail(
1227+
matcher=lambda sample: "dim" not in sample.kwargs,
1228+
reason="this overload requires the 'dim' attribute by design",
1229+
)
12131230
.skip(
12141231
enabled_if=ops_test_common.IS_WINDOWS,
12151232
reason="fixme: ORT has memory errors. https://github.com/microsoft/onnxruntime/issues/16492",
@@ -1218,7 +1235,24 @@ def _where_input_wrangler(
12181235
dtypes=(torch.int64,),
12191236
reason="fixme: ORT did not implement ArgMax for int64. https://github.com/microsoft/onnxruntime/issues/16654",
12201237
),
1221-
TorchLibOpInfo("argmin", core_ops.aten_argmin, trace_only=True)
1238+
TorchLibOpInfo("argmin", core_ops.aten_argmin)
1239+
.skip(
1240+
matcher=lambda sample: "dim" in sample.kwargs,
1241+
reason="this overload does not support the 'dim' attribute by design",
1242+
)
1243+
.skip(
1244+
enabled_if=ops_test_common.IS_WINDOWS,
1245+
reason="fixme: ORT has memory errors. https://github.com/microsoft/onnxruntime/issues/16492",
1246+
)
1247+
.xfail(
1248+
dtypes=(torch.int64,),
1249+
reason="fixme: ORT did not implement ArgMin for int64. https://github.com/microsoft/onnxruntime/issues/16654",
1250+
),
1251+
TorchLibOpInfo("argmin_dim", core_ops.aten_argmin_dim)
1252+
.xfail(
1253+
matcher=lambda sample: "dim" not in sample.kwargs,
1254+
reason="this overload requires the 'dim' attribute by design",
1255+
)
12221256
.skip(
12231257
enabled_if=ops_test_common.IS_WINDOWS,
12241258
reason="fixme: ORT has memory errors. https://github.com/microsoft/onnxruntime/issues/16492",
@@ -1631,6 +1665,8 @@ def _where_input_wrangler(
16311665
ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim",))
16321666
ops_test_common.duplicate_opinfo(OPS_DB, "any", ("any_dim",))
16331667
ops_test_common.duplicate_opinfo(OPS_DB, "arange", ("arange_start", "arange_start_step"))
1668+
ops_test_common.duplicate_opinfo(OPS_DB, "argmax", ("argmax_dim",))
1669+
ops_test_common.duplicate_opinfo(OPS_DB, "argmin", ("argmin_dim",))
16341670
ops_test_common.duplicate_opinfo(OPS_DB, "atleast_1d", ("atleast_1d_single_tensor",))
16351671
ops_test_common.duplicate_opinfo(OPS_DB, "atleast_2d", ("atleast_2d_single_tensor",))
16361672
ops_test_common.duplicate_opinfo(OPS_DB, "atleast_3d", ("atleast_3d_single_tensor",))

0 commit comments

Comments
 (0)