feat(atenlib): add ops (any, expand_as) #463
Changes from all commits
@@ -269,10 +269,32 @@ def aten_angle(self: TensorType) -> TensorType:
     raise NotImplementedError()


-def aten_any(self: TensorType) -> TensorType:
+@torch_op("aten::any", trace_only=True)
+def aten_any(self: TTensor, dim: Optional[int] = None, keepdim: bool = True) -> BOOL:
Review thread on the new aten_any signature (marked resolved by xiaowuhu):

  - I think keepdim should default to False, or set as [...]. That way we can avoid hard coding [...].
  - Shouldn't the default value be determined by the spec of the aten op?
  - Looking at it again, I figured we are supporting 2 overloads of [...]. And for [...]
  - cc @justinchuby: it is weird that the wrong default is not caught in the op test; I wonder if it is just not enough coverage in the original opinfo or something else.
  - I think we should not have combined the two overloads, and yes - you have a great point that we should verify the opinfo to assert how much confidence it can provide us for each op before trusting it. I will add this to the authoring guide.
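For reference, the two ATen overloads the reviewers refer to are roughly the following (schemas quoted from memory; verify against native_functions.yaml):

    any(Tensor self) -> Tensor
    any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor

i.e. the spec default for keepdim is False, which is the point behind the comments above.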
     """any(Tensor self) -> Tensor"""

-    raise NotImplementedError()
+    minus_1 = op.Constant(value_ints=[-1])
+
+    self_rank = op.Size(op.Shape(self))
+    if self_rank == 0:
+        self = op.Reshape(self, minus_1)
+
+    zero = op.Constant(value_float=0.0)
+    result = op.Not(op.Equal(self, zero))
+
+    # because op.ReduceMax() cannot calculate BOOL value
+    result_float = op.Cast(result, to=FLOAT.dtype)
+
+    if op.OptionalHasElement(dim):
+        dim = op.Reshape(dim, minus_1)
+        dims = op.Cast(dim, to=INT64.dtype)
+        result_max = op.ReduceMax(result_float, dims, keepdims=keepdim, noop_with_empty_axes=0)
+    else:
+        result_max = op.ReduceMax(result_float, keepdims=0, noop_with_empty_axes=0)
+
+    result = op.Greater(result_max, zero)
+    if self_rank == 0:
+        result = op.Squeeze(result)
+
+    return result


 def _range_supported(dtype: int) -> bool:
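As a sanity check on the reduction pattern above, here is a rough NumPy model of what the translated graph computes (illustrative only; the helper name and the use of NumPy are not part of the PR):

```python
from typing import Optional

import numpy as np


def any_reference(x: np.ndarray, dim: Optional[int] = None, keepdim: bool = False) -> np.ndarray:
    # Equal/Not: mark nonzero elements, then Cast to float because ONNX
    # ReduceMax does not accept BOOL inputs.
    nonzero = (x != 0).astype(np.float32)
    if dim is None:
        # ReduceMax over all axes with keepdims=0 collapses to a scalar.
        reduced = nonzero.max()
    else:
        reduced = nonzero.max(axis=dim, keepdims=keepdim)
    # Greater(reduced, 0) turns the float max back into a boolean result.
    return np.asarray(reduced > 0)


assert any_reference(np.array([[0, 0], [0, 2]]), dim=1).tolist() == [False, True]
assert bool(any_reference(np.array(0.0))) is False
```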
@@ -2044,10 +2066,13 @@ def aten_expand(self: TTensor, size: TInt) -> TTensor:
     return op.Expand(self, size)


-def aten_expand_as(self: TensorType, other: TensorType) -> TensorType:
+@torch_op("aten::expand_as")
+def aten_expand_as(self: TTensor, other: TTensor) -> TTensor:
     """expand_as(Tensor(a) self, Tensor other) -> Tensor(a)"""

-    raise NotImplementedError()
+    shape = op.Shape(other)
+    result = op.Expand(self, shape)
+    return result


 def aten_expand_copy(self: TensorType, size: INT64, implicit: bool = False) -> TensorType:
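For context, the behavior the Shape + Expand translation has to reproduce is PyTorch's expand_as broadcasting; the example below is not from the PR, just an illustration:

```python
import torch

self_t = torch.tensor([[1.0], [2.0]])  # shape (2, 1)
other = torch.zeros(2, 3)              # shape (2, 3)

# expand_as broadcasts `self_t` to `other`'s shape without copying data;
# op.Expand(self, op.Shape(other)) mirrors this in ONNX.
expanded = self_t.expand_as(other)
assert expanded.shape == other.shape
assert torch.equal(expanded, torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]))
```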
Review thread on the trace_only=True decoration of aten_any:

  - Why do we need trace_only here? Looks like all of the operations are safe.
  - Due to the Optional argument parsing.
  - Could you say more on what optional argument parsing is? I would also add a comment in the code on why trace_only was needed.
  - If trace_only is not used, it throws an exception:
        E TypeError: Required input 'dim: Attribute[None]' was not provide
  - I think this is the same as the issue we discussed in today's meeting:
        if (OptionalHasElement(x)) then use(x)
    is an issue. We need to find a solution to this. One possibility is to extend OptionalGetElement, but I wonder if that's enough for all cases.
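To make the optional-argument point concrete, here is a rough sketch (not the PR's code; it assumes the same torch_op/op/FLOAT/INT64/BOOL/TTensor imports used elsewhere in the file, omits the rank-0 handling, and uses keepdim=False following the spec default discussed above) of how the dim=None case could be resolved with a plain Python branch when the function runs in trace_only mode:

```python
# Illustrative sketch only. With trace_only=True the body executes as ordinary
# Python during export, so `dim is None` is decided at trace time instead of
# building op.OptionalHasElement into the graph.
@torch_op("aten::any", trace_only=True)
def aten_any_sketch(self: TTensor, dim: Optional[int] = None, keepdim: bool = False) -> BOOL:
    zero = op.Constant(value_float=0.0)
    # Mark nonzero elements, then cast to float because ReduceMax rejects BOOL.
    result_float = op.Cast(op.Not(op.Equal(self, zero)), to=FLOAT.dtype)
    if dim is None:  # plain Python branch, resolved at trace time
        reduced = op.ReduceMax(result_float, keepdims=0)
    else:
        dim_t = op.Reshape(dim, op.Constant(value_ints=[-1]))
        dims = op.Cast(dim_t, to=INT64.dtype)
        reduced = op.ReduceMax(result_float, dims, keepdims=keepdim)
    return op.Greater(reduced, zero)
```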