
feat(atenlib): add ops (any, expand_as) #463


Merged · 18 commits · Mar 2, 2023
33 changes: 29 additions & 4 deletions onnxscript/function_libs/torch_aten/ops/core.py
@@ -269,10 +269,32 @@ def aten_angle(self: TensorType) -> TensorType:
raise NotImplementedError()


def aten_any(self: TensorType) -> TensorType:
@torch_op("aten::any", trace_only=True)
Contributor:
Why do we need trace_only here? It looks like all of the operations are safe.

Contributor Author:
due to the Optional argument parsing.

Collaborator:
Could you say more about what the optional argument parsing issue is? I would also add a comment in the code on why trace_only was needed.

Contributor Author:
If trace_only is not used, it throws an exception:


E   TypeError: Required input 'dim: Attribute[None]' was not provide

Collaborator:
I think this is the same issue we discussed in today's meeting: the pattern "if (OptionalHasElement(x)) then use(x)" is a problem. We need to find a solution for this. One possibility is to extend OptionalGetElement, but I wonder whether that is enough for all cases.
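To make the issue concrete, here is a minimal sketch (not the PR's code) of what trace_only buys us: with trace_only=True the function body runs as ordinary Python while the graph is built, so the Optional dim can be handled with a plain Python branch instead of ONNX Optional* operators. It assumes op and FLOAT are the opset alias and tensor type already imported at the top of core.py; the helper name is made up.

```python
# Illustration only; not the merged implementation.
# Assumes `op` and `FLOAT` from core.py's existing imports.
def aten_any_sketch(self, dim=None, keepdim: bool = False):
    zero = op.Constant(value_float=0.0)
    is_nonzero = op.Cast(op.Not(op.Equal(self, zero)), to=FLOAT.dtype)
    if dim is None:
        # Plain Python branch: resolved during tracing, never emitted into the graph.
        reduced = op.ReduceMax(is_nonzero, keepdims=0)
    else:
        axes = op.Constant(value_ints=[dim])
        reduced = op.ReduceMax(is_nonzero, axes, keepdims=int(keepdim))
    return op.Greater(reduced, zero)
```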

def aten_any(self: TTensor, dim: Optional[int] = None, keepdim: bool = True) -> BOOL:
Contributor:
I think keepdim should default to False, or set as keepdim: Optional[bool] = None, and explicitly set to False if None.

That way we can avoid hard coding keepdims=0 later.

https://github.com/microsoft/onnx-script/blob/ec7c26471e0ef1896ba492549f33d0367b0301e4/onnxscript/function_libs/torch_aten/ops/core.py#L292

Collaborator:
Shouldn't the default value be determined by the spec of the aten op?

Contributor @BowenBao (Mar 4, 2023):
Looking at it again, I see that we are supporting two overloads of aten::any with this function. For - func: any(Tensor self) -> Tensor we actually need to decide what the behavior should be, since dim and keepdim are not in that signature and have no defaults there. So I think this is fine.

And for - func: any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor, yes, you are right: the current default is wrong and it should be False.

- func: any(Tensor self) -> Tensor
  device_check: NoCheck   # TensorIterator
  structured_delegate: any.all_out
  variants: method, function
  dispatch:
    SparseCPU, SparseCUDA: any_sparse

- func: any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor
  device_check: NoCheck   # TensorIterator
  structured_delegate: any.out
  variants: function, method

Contributor:
cc @justinchuby: it is odd that the wrong default is not caught by the op test. I wonder whether it is just insufficient coverage in the original OpInfo, or something else.

Collaborator:
I think we should not have combined the two overloads. And yes, you make a great point: we should verify the OpInfo to assess how much confidence it can give us for each op before trusting it. I will add this to the authoring guide.
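For reference, a hedged sketch of what splitting the overloads could look like, following the any / any.dim entries from native_functions.yaml quoted above. The registration string for the .dim overload, the trace_only usage, and the function names are assumptions, not the repository's established convention; the point is that the no-dim overload needs no Optional handling at all, while the .dim overload gets keepdim=False as its default.

```python
# Hedged sketch only; registration names and decorator arguments are assumed.
@torch_op("aten::any")
def aten_any(self: TTensor) -> BOOL:
    """any(Tensor self) -> Tensor"""
    # No dim/keepdim in this overload, so no Optional parsing is needed.
    self_abs = op.Abs(op.Cast(self, to=FLOAT.dtype))
    reduced = op.ReduceMax(self_abs, keepdims=0)
    return op.Greater(reduced, op.Constant(value_float=0.0))


@torch_op("aten::any.dim", trace_only=True)  # assumed overload registration
def aten_any_dim(self: TTensor, dim: int, keepdim: bool = False) -> BOOL:
    """any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor"""
    self_abs = op.Abs(op.Cast(self, to=FLOAT.dtype))
    axes = op.Constant(value_ints=[dim])
    reduced = op.ReduceMax(self_abs, axes, keepdims=int(keepdim))
    return op.Greater(reduced, op.Constant(value_float=0.0))
```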

"""any(Tensor self) -> Tensor"""

raise NotImplementedError()
minus_1 = op.Constant(value_ints=[-1])
self_rank = op.Size(op.Shape(self))
if self_rank == 0:
    self = op.Reshape(self, minus_1)

zero = op.Constant(value_float=0.0)
result = op.Not(op.Equal(self, zero))
# Cast to FLOAT because op.ReduceMax() cannot reduce BOOL values.
result_float = op.Cast(result, to=FLOAT.dtype)

if op.OptionalHasElement(dim):
    dim = op.Reshape(dim, minus_1)
    dims = op.Cast(dim, to=INT64.dtype)
    result_max = op.ReduceMax(result_float, dims, keepdims=keepdim, noop_with_empty_axes=0)
else:
    result_max = op.ReduceMax(result_float, keepdims=0, noop_with_empty_axes=0)

result = op.Greater(result_max, zero)
if self_rank == 0:
    result = op.Squeeze(result)

return result
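As a quick reference for the semantics being matched (and for the keepdim default discussed above), here is a small sanity check against PyTorch itself; this is illustrative only and not part of the PR or its tests.

```python
import torch

x = torch.tensor([[0.0, 0.0], [0.0, -0.1]])
assert torch.any(x).item()                                 # any(Tensor self): True, one element is nonzero
assert torch.any(x, dim=0).tolist() == [False, True]       # any.dim: reduce along dim 0
assert torch.any(x, dim=0, keepdim=True).shape == (1, 2)   # keepdim keeps the reduced axis
```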


def _range_supported(dtype: int) -> bool:
@@ -2044,10 +2066,13 @@ def aten_expand(self: TTensor, size: TInt) -> TTensor:
return op.Expand(self, size)


def aten_expand_as(self: TensorType, other: TensorType) -> TensorType:
@torch_op("aten::expand_as")
def aten_expand_as(self: TTensor, other: TTensor) -> TTensor:
"""expand_as(Tensor(a) self, Tensor other) -> Tensor(a)"""

raise NotImplementedError()
shape = op.Shape(other)
result = op.Expand(self, shape)
return result
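A small hedged illustration of the Shape-then-Expand mapping, using PyTorch's expand_as as the reference for the intended behavior (not part of the PR):

```python
import torch

a = torch.tensor([[1.0], [2.0]])   # shape (2, 1)
b = torch.zeros(2, 3)
# expand_as broadcasts self to other's shape, which is what Shape + Expand do in ONNX.
assert a.expand_as(b).shape == (2, 3)
assert a.expand_as(b).tolist() == [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]
```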


def aten_expand_copy(self: TensorType, size: INT64, implicit: bool = False) -> TensorType:
@@ -303,6 +303,7 @@ def _where_input_wrangler(
"exp": core_ops.aten_exp,
"exp2": core_ops.aten_exp2,
"expand": core_ops.aten_expand,
"expand_as": core_ops.aten_expand_as,
"erf": core_ops.aten_erf,
"fill": core_ops.aten_fill,
"fmod": core_ops.aten_fmod,
@@ -387,6 +388,7 @@ def _where_input_wrangler(
] = {
"amax": core_ops.aten_amax,
"amin": core_ops.aten_amin,
"any": core_ops.aten_any, # TODO: add more testcase which element is [0.0, 0.1, -0.1, 0.0] etc.
"arange_start_step": core_ops.aten_arange_start_step,
"arange_start": core_ops.aten_arange_start,
"arange": core_ops.aten_arange,