Skip to content

Commit 58173d6

Browse files
committed
Merge remote-tracking branch 'upstream/main' into xiaowu/addOps(layer_norm)
2 parents 0570b92 + c5ca05b commit 58173d6

File tree

2 files changed

+51
-2
lines changed

2 files changed

+51
-2
lines changed

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3267,10 +3267,58 @@ def aten_matrix_power(self: TensorType, n: int) -> TensorType:
32673267
raise NotImplementedError()
32683268

32693269

3270-
@torch_op("aten::max", trace_only=True)
def aten_max(
    self: TReal, dim_or_other: Union[TReal, INT64] = None, keepdim: BOOL = None
) -> TReal:
    """max(Tensor self) -> Tensor

    Trace-only dispatcher for the three torch.max overloads:
      * max(self)            -> reduce over all elements      (1 output)
      * max(self, dim, keepdim) -> (values, indices) along dim (2 outputs)
      * max(self, other)     -> elementwise maximum            (1 output)
    """

    self_rank = op.Size(op.Shape(self))
    if self_rank == 0:
        # Promote a 0-d input to 1-d so the Reduce/ArgMax ops have an axis;
        # Constant's int-list attribute is `value_ints` (`value_int` is scalar-only).
        self = op.Reshape(self, op.Constant(value_ints=[-1]))

    # Number of outputs to produce: 1 = values only, 2 = (values, indices).
    output = 1

    # NOTE(review): mixing op.OptionalHasElement with Python-level isinstance
    # in a trace_only function assumes dim_or_other/keepdim are plain Python
    # values during tracing — confirm against the trace_only calling convention.
    if op.OptionalHasElement(dim_or_other):
        if isinstance(dim_or_other, int):
            if not op.OptionalHasElement(keepdim):
                keepdim = False
            result, indices = _aten_max_with_dim(self, dim_or_other, keepdim)
            output = 2
        else:  # dim_or_other is a tensor: elementwise max(self, other)
            result = _aten_max_with_other(self, dim_or_other)
    else:
        result = _aten_max_with_no_dim(self)

    if self_rank == 0:
        # Undo the 0-d -> 1-d promotion so a scalar input yields a scalar result.
        result = op.Squeeze(result)

    if output == 2:
        if self_rank == 0:
            indices = op.Squeeze(indices)  # type: ignore[has-type]
        return result, indices
    return result
3301+
3302+
3303+
@torch_op("aten::max", overload=True)
3304+
def _aten_max_with_no_dim(self: TReal) -> TReal:
3305+
result = op.ReduceMax(self, keepdims=0)
3306+
return result
3307+
3308+
3309+
@torch_op("aten::max", overload=True)
3310+
def _aten_max_with_other(self: TReal, other: TReal) -> TReal:
3311+
result = op.Max(self, other)
3312+
return result
3313+
3314+
3315+
@torch_op("aten::max", overload=True)
3316+
# def _aten_max_with_dim(self: TReal, dim: int, keepdim: bool) -> tuple[TReal, TInt]:
3317+
def _aten_max_with_dim(self: TReal, dim: int, keepdim: bool):
3318+
dims = op.Reshape(dim, op.Constant(value_int=[-1]))
3319+
result = op.ReduceMax(self, dims, keepdims=keepdim)
3320+
indices = op.ArgMax(self, axis=dim, keepdims=keepdim)
3321+
return result, indices
32743322

32753323

32763324
def aten_max_pool1d(

onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,7 @@ def _where_input_wrangler(
399399
"empty_like": core_ops.aten_empty_like,
400400
"index_select": core_ops.aten_index_select,
401401
"layer_norm": core_ops.aten_layer_norm,
402+
"max": core_ops.aten_max,
402403
"native_layer_norm": core_ops.aten_native_layer_norm,
403404
"new_empty": core_ops.aten_new_empty,
404405
"new_empty_strided": core_ops.aten_new_empty_strided,

0 commit comments

Comments
 (0)