feat(atenlib): implement aten functions 1/n #247
Conversation
To allow mypy to analyze typing for annotated functions. Otherwise it complains that "Untyped decorator makes function "ones_like" untyped [misc]" [ghstack-poisoned]
Codecov Report

@@            Coverage Diff             @@
##             main     #247      +/-   ##
==========================================
+ Coverage   71.60%   71.68%    +0.07%
==========================================
  Files          93       93
  Lines        8835     8835
==========================================
+ Hits         6326     6333        +7
+ Misses       2509     2502        -7
def ones_like(x, onnx_dtype: int):
    shape = op.Shape(x)
    return op.ConstantOfShape(
        shape, value=onnx.helper.make_tensor("one", onnx_dtype, [1], [1])
    )
One issue here is that ONNX attributes should be compile-time (that is, script-time) constants and hence cannot depend on a parameter like onnx_dtype. Is the idea that the output type should be the same as the input type x? Or can it be a different dtype?
Anyway, I would suggest using either CastLike or Cast, depending on what exactly you want.
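For reference, a minimal sketch of the CastLike route (the imports, opset version, and function name are assumptions for illustration, not the PR's code):

from onnxscript import script
from onnxscript import opset15 as op  # opset version is an assumption

@script()
def ones_like_castlike(x):
    # Make the constant follow x's dtype instead of taking a dtype attribute.
    shape = op.Shape(x)
    one = op.Constant(value_float=1.0)
    return op.Expand(op.CastLike(one, x), shape)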
The only thing we can do with an attribute-parameter in ONNX is to forward it as the attribute-value of another call.
Here's one option:
def ones_like(x, dtype: int):
    shape = op.Shape(x)
    one = op.Constant(value_float=1.0)
    one_dtype = op.Cast(one, to=dtype)
    return op.Expand(one_dtype, shape)
I wrote it out in full to focus on the ONNX representation. In onnxscript, we can abbreviate it a little bit, as below:
def ones_like(x, dtype: int):
    shape = op.Shape(x)
    one_dtype = op.Cast(1, to=dtype)  # Forwarding an attribute-parameter as an attribute-value is ok.
    return op.Expand(one_dtype, shape)
The constraints come from various ONNX ops.
The other point I would like to understand is about the intended usage. In an exporter, for example, you could choose to:
- add multiple ONNX ops (a subgraph) to the graph being constructed (representing a single pytorch op), or
- add a single call to an ONNX function (op).
They are slightly different (for example, the first approach is more general). The first can be done using an approach we have discussed previously, something like the following:
def specialized_ones_like(dtype: int):
    # This is outside the script, so okay:
    one_dtype = onnx.helper.make_tensor("one", dtype, [1], [1])

    @script()
    def f(x):
        shape = op.Shape(x)
        return op.ConstantOfShape(shape, value=one_dtype)

    return f.to_function_proto().node  # Depending on intended usage
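A possible usage sketch of the factory above (the variable name is hypothetical):

import onnx

# Specialize once per dtype at export time; the returned value is the list of
# NodeProto objects making up the function body, which the exporter could
# splice into the graph being built.
ones_like_float_nodes = specialized_ones_like(onnx.TensorProto.FLOAT)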
Thanks for the example! I love how clean it is. I am probably going to think more about the specialized version before using it, since it requires the onnx function to be defined in a function scope to be able to close over dtype.
Question: f.to_function_proto().node looks cool; I didn't know about it. What is the node we get here?
Done
Implemented ops - lt - gt - round - clamp_min - clamp_max - clamp - repeat - ones_like [ghstack-poisoned]
# repeat(Tensor self, SymInt[] repeats) -> Tensor

raise NotImplementedError()
# FIXME(justinchuby): When repeats.shape == [0]
This is an example where we need shape-dependent logic.
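As a rough illustration of what shape-dependent logic could look like here, a sketch only: it assumes onnxscript lowers the Python if into an ONNX If node, and it ignores the rank alignment a full aten::repeat needs. Names and opset are assumptions, not the PR's code.

from onnxscript import script
from onnxscript import opset15 as op  # opset version is an assumption

@script()
def repeat_sketch(self, repeats):
    # torch.Tensor.repeat returns the input unchanged when repeats is empty.
    if op.Size(repeats) == 0:
        result = op.Identity(self)
    else:
        # Tile assumes repeats already matches the input rank (omitted here).
        result = op.Tile(self, repeats)
    return result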
@@ -747,16 +749,31 @@ def aten_clamp(
     raise NotImplementedError()

-def aten_clamp_max(self: TensorType, max: float) -> TensorType:
+def aten_clamp_max_scalar(self, max_):
Is it possible to validate the shape of max_ to know whether it is a scalar or a tensor, instead of encoding that information in the function name?
I haven't thought of a good way. Any suggestions?
Ah, for validation, we can specify it in the signature / via some data structure. We can discuss today.
They still need to be two functions though.
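For illustration, the two variants might look roughly like this (a sketch assuming CastLike/Min are acceptable lowerings; the names and opset are not the PR's):

from onnxscript import script
from onnxscript import opset15 as op  # opset version is an assumption

@script()
def aten_clamp_max_scalar_sketch(self, max_: float):
    # The scalar bound arrives as an attribute, so turn it into a constant
    # and cast it to the input's dtype before clamping.
    bound = op.CastLike(op.Constant(value_float=max_), self)
    return op.Min(self, bound)

@script()
def aten_clamp_max_tensor_sketch(self, max_):
    # The tensor bound is a regular input; Min broadcasts as needed.
    return op.Min(self, max_)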
Implemented ops - lt - gt - round - clamp_min - clamp_max - clamp - repeat - ones_like Test: - Create the ability to skip particular sub tests - Create a helper function to transform torch inputs into numpy for onnxscript to run on [ghstack-poisoned]
        return input.detach().cpu().numpy()
    if isinstance(input, (tuple, list)):
        if len(input) == 0:
            return np.array((), dtype=np.int64)
Why is the default value not float32?
I found that it's usually an index tuple, so it needs to be int. If there are more subtle cases, we can update the logic to handle them.
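A fuller sketch of such a conversion helper, for context (the name and exact branches are illustrative, not necessarily the PR's code):

import numpy as np
import torch

def convert_input_to_numpy(value):
    # Tensors are detached and moved to CPU before conversion.
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    if isinstance(value, (tuple, list)):
        # Empty sequences are usually index tuples, so default to int64.
        if len(value) == 0:
            return np.array((), dtype=np.int64)
        return np.array(value)
    return value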
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* #247
* __->__ #248

Also added annotations to `script()`, to allow mypy to analyze typing for annotated functions. Otherwise it complains that "Untyped decorator makes function "ones_like" untyped [misc]".
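A minimal sketch of the decorator-typing pattern being referred to (the real script() in onnxscript returns an OnnxFunction wrapper, so its actual annotations differ from this simplified version):

from typing import Any, Callable, TypeVar

_F = TypeVar("_F", bound=Callable[..., Any])

def script(**kwargs: Any) -> Callable[[_F], _F]:
    # Once the decorator factory has a return annotation, mypy no longer
    # reports "Untyped decorator makes function ... untyped [misc]".
    def decorator(func: _F) -> _F:
        return func
    return decorator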
LGTM, thanks!
Stack from ghstack (oldest at bottom):
Implemented ops: lt, gt, round, clamp_min, clamp_max, clamp, repeat, ones_like
Test:
- Create the ability to skip particular sub tests
- Create a helper function to transform torch inputs into numpy for onnxscript to run on