diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index 712791f934..91c834fee9 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -3257,10 +3257,58 @@ def aten_matrix_power(self: TensorType, n: int) -> TensorType: raise NotImplementedError() -def aten_max(self: TensorType) -> TensorType: +@torch_op("aten::max", trace_only=True) +def aten_max( + self: TReal, dim_or_other: Union[TReal, INT64] = None, keepdim: BOOL = None +) -> TReal: """max(Tensor self) -> Tensor""" - raise NotImplementedError() + self_rank = op.Size(op.Shape(self)) + if self_rank == 0: + self = op.Reshape(self, op.Constant(value_int=[-1])) + + output = 1 + + if op.OptionalHasElement(dim_or_other): + if isinstance(dim_or_other, int): + if not op.OptionalHasElement(keepdim): + keepdim = False + result, indices = _aten_max_with_dim(self, dim_or_other, keepdim) + output = 2 + else: # dim_or_other is tensor + result = _aten_max_with_other(self, dim_or_other) + else: + result = _aten_max_with_no_dim(self) + + if self_rank == 0: + result = op.Squeeze(result) + + if output == 2: + if self_rank == 0: + indices = op.Squeeze(indices) # type: ignore[has-type] + return result, indices + return result + + +@torch_op("aten::max", overload=True) +def _aten_max_with_no_dim(self: TReal) -> TReal: + result = op.ReduceMax(self, keepdims=0) + return result + + +@torch_op("aten::max", overload=True) +def _aten_max_with_other(self: TReal, other: TReal) -> TReal: + result = op.Max(self, other) + return result + + +@torch_op("aten::max", overload=True) +# def _aten_max_with_dim(self: TReal, dim: int, keepdim: bool) -> tuple[TReal, TInt]: +def _aten_max_with_dim(self: TReal, dim: int, keepdim: bool): + dims = op.Reshape(dim, op.Constant(value_int=[-1])) + result = op.ReduceMax(self, dims, keepdims=keepdim) + indices = op.ArgMax(self, axis=dim, keepdims=keepdim) + return result, indices def aten_max_pool1d( diff --git a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py index aa26c20631..5df66b638f 100644 --- a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py @@ -398,6 +398,7 @@ def _where_input_wrangler( "convolution": core_ops.aten_convolution, "empty_like": core_ops.aten_empty_like, "index_select": core_ops.aten_index_select, + "max": core_ops.aten_max, "native_layer_norm": core_ops.aten_native_layer_norm, "new_empty": core_ops.aten_new_empty, "new_empty_strided": core_ops.aten_new_empty_strided,