feat(atenlib): add op(max) #314

Merged: 32 commits, Mar 1, 2023
Commits (32)
362d800  Update core.py (xiaowuhu, Jan 11, 2023)
bb706ba  test (xiaowuhu, Jan 11, 2023)
5eddec2  Update core.py (xiaowuhu, Jan 12, 2023)
82b212a  Update core.py (xiaowuhu, Jan 12, 2023)
cbda011  Update core.py (xiaowuhu, Jan 12, 2023)
b4686f4  update (xiaowuhu, Jan 12, 2023)
a212b26  Merge remote-tracking branch 'upstream/main' into xiaowu/addOp(Mean) (xiaowuhu, Jan 12, 2023)
575dbd8  Update core.py (xiaowuhu, Jan 12, 2023)
dcb7344  Merge remote-tracking branch 'upstream/main' into xiaowu/addOp(Mean) (xiaowuhu, Jan 13, 2023)
9041d19  update (xiaowuhu, Jan 13, 2023)
be39312  Update core.py (xiaowuhu, Jan 19, 2023)
497785c  Merge branch 'main' into pr/314 (xiaowuhu, Jan 19, 2023)
cae8f2b  Merge remote-tracking branch 'upstream/main' into xiaowu/addOp(Mean) (xiaowuhu, Feb 9, 2023)
e500df9  update (xiaowuhu, Feb 9, 2023)
d310501  update (xiaowuhu, Feb 10, 2023)
b7108d5  Merge remote-tracking branch 'upstream/main' into xiaowu/addOp(Mean) (xiaowuhu, Feb 28, 2023)
d793af7  update (xiaowuhu, Feb 28, 2023)
f04a9c8  update (xiaowuhu, Feb 28, 2023)
298fa8f  Update core.py (xiaowuhu, Feb 28, 2023)
48c9b40  Update evaluator.py (xiaowuhu, Feb 28, 2023)
ce5ccf1  Merge remote-tracking branch 'upstream/main' into xiaowu/addOp(Mean) (xiaowuhu, Feb 28, 2023)
790f533  Update ops_correctness_test.py (xiaowuhu, Feb 28, 2023)
187b8a7  Update ops_correctness_test.py (xiaowuhu, Feb 28, 2023)
d81e75a  Update core.py (xiaowuhu, Feb 28, 2023)
8a29bcf  Update core.py (xiaowuhu, Feb 28, 2023)
d90444a  Merge branch 'main' into xiaowu/addOp(Mean) (xiaowuhu, Mar 1, 2023)
36aad00  Update core.py (xiaowuhu, Mar 1, 2023)
5f51c2c  Merge branch 'xiaowu/addOp(Mean)' of https://github.com/xiaowuhu/onnx… (xiaowuhu, Mar 1, 2023)
8276c58  Update core.py (xiaowuhu, Mar 1, 2023)
be968c0  update (xiaowuhu, Mar 1, 2023)
2258ed8  Merge remote-tracking branch 'upstream/main' into xiaowu/addOp(Mean) (xiaowuhu, Mar 1, 2023)
0daa0a1  Update core.py (xiaowuhu, Mar 1, 2023)
52 changes: 50 additions & 2 deletions onnxscript/function_libs/torch_aten/ops/core.py
@@ -3257,10 +3257,58 @@ def aten_matrix_power(self: TensorType, n: int) -> TensorType:
     raise NotImplementedError()


-def aten_max(self: TensorType) -> TensorType:
+@torch_op("aten::max", trace_only=True)
+def aten_max(
+    self: TReal, dim_or_other: Union[TReal, INT64] = None, keepdim: BOOL = None
+) -> TReal:
     """max(Tensor self) -> Tensor"""

-    raise NotImplementedError()
+    self_rank = op.Size(op.Shape(self))
+    if self_rank == 0:
+        self = op.Reshape(self, op.Constant(value_int=[-1]))
+
+    output = 1
+
+    if op.OptionalHasElement(dim_or_other):
+        if isinstance(dim_or_other, int):
+            if not op.OptionalHasElement(keepdim):
+                keepdim = False
+            result, indices = _aten_max_with_dim(self, dim_or_other, keepdim)
+            output = 2
+        else:  # dim_or_other is tensor
+            result = _aten_max_with_other(self, dim_or_other)
+    else:
+        result = _aten_max_with_no_dim(self)
+
+    if self_rank == 0:
+        result = op.Squeeze(result)
+
+    if output == 2:
+        if self_rank == 0:
+            indices = op.Squeeze(indices)  # type: ignore[has-type]
+        return result, indices
+    return result
+
+
+@torch_op("aten::max", overload=True)
+def _aten_max_with_no_dim(self: TReal) -> TReal:
+    result = op.ReduceMax(self, keepdims=0)
+    return result
+
+
+@torch_op("aten::max", overload=True)
+def _aten_max_with_other(self: TReal, other: TReal) -> TReal:
+    result = op.Max(self, other)
+    return result
+
+
+@torch_op("aten::max", overload=True)
+# def _aten_max_with_dim(self: TReal, dim: int, keepdim: bool) -> tuple[TReal, TInt]:
[Review comment from a Collaborator on the line above: nit: Dead code]

+def _aten_max_with_dim(self: TReal, dim: int, keepdim: bool):
+    dims = op.Reshape(dim, op.Constant(value_int=[-1]))
+    result = op.ReduceMax(self, dims, keepdims=keepdim)
+    indices = op.ArgMax(self, axis=dim, keepdims=keepdim)
+    return result, indices


 def aten_max_pool1d(
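For context on the dispatch above: because `aten_max` is registered with `trace_only=True`, the `isinstance` check on `dim_or_other` runs in Python at export time and selects one of three lowerings. The sketch below is not part of the PR; it mirrors the same three `aten::max` overloads in plain NumPy (the name `max_reference` is made up for illustration) so the mapping to `ReduceMax`, `ReduceMax` + `ArgMax`, and `Max` is easy to see.

```python
import numpy as np


def max_reference(self, dim_or_other=None, keepdim=False):
    """Illustrative NumPy mirror of the three aten::max overloads (not the PR code)."""
    if dim_or_other is None:
        # max(Tensor self) -> Tensor: full reduction, like op.ReduceMax(..., keepdims=0)
        return np.max(self)
    if isinstance(dim_or_other, int):
        # max(Tensor self, int dim, bool keepdim) -> (values, indices):
        # like op.ReduceMax plus op.ArgMax over the same axis
        values = np.max(self, axis=dim_or_other, keepdims=keepdim)
        indices = np.argmax(self, axis=dim_or_other)
        if keepdim:
            indices = np.expand_dims(indices, axis=dim_or_other)
        return values, indices
    # max(Tensor self, Tensor other) -> Tensor: elementwise maximum, like op.Max
    return np.maximum(self, dim_or_other)


x = np.array([[1.0, 5.0], [3.0, 2.0]])
print(max_reference(x))                        # 5.0
print(max_reference(x, 1, keepdim=True))       # ([[5.], [3.]], [[1], [0]])
print(max_reference(x, np.full_like(x, 4.0)))  # [[4., 5.], [4., 4.]]
```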
ops_correctness_test.py
@@ -398,6 +398,7 @@ def _where_input_wrangler(
"convolution": core_ops.aten_convolution,
"empty_like": core_ops.aten_empty_like,
"index_select": core_ops.aten_index_select,
"max": core_ops.aten_max,
"native_layer_norm": core_ops.aten_native_layer_norm,
"new_empty": core_ops.aten_new_empty,
"new_empty_strided": core_ops.aten_new_empty_strided,
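The one-line registration above is what routes the `max` OpInfo samples through `core_ops.aten_max` in the correctness test. As a rough standalone illustration of the kind of behavior that test compares against (this is not the repository's actual OpInfo harness; `check_max_overloads` is a hypothetical name):

```python
import numpy as np
import torch


def check_max_overloads():
    """Hypothetical standalone check mirroring what the OpInfo-driven test covers."""
    x = torch.tensor([[1.0, 5.0], [3.0, 2.0]])
    other = torch.full_like(x, 4.0)

    # Overload 1: max(Tensor) -> Tensor (full reduction)
    np.testing.assert_allclose(torch.max(x).numpy(), np.max(x.numpy()))

    # Overload 2: max(Tensor, dim, keepdim) -> (values, indices)
    values, indices = torch.max(x, dim=1, keepdim=True)
    np.testing.assert_allclose(values.numpy(), np.max(x.numpy(), axis=1, keepdims=True))
    np.testing.assert_array_equal(indices.numpy(), np.array([[1], [0]]))

    # Overload 3: max(Tensor, Tensor) -> Tensor (elementwise)
    np.testing.assert_allclose(torch.max(x, other).numpy(), np.maximum(x.numpy(), other.numpy()))


check_max_overloads()
```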