feat(atenlib): use typevars for functions #261
Conversation
Codecov Report

```
@@           Coverage Diff            @@
##             main      #261   +/-  ##
=======================================
+ Coverage   72.27%    72.39%   +0.12%
=======================================
  Files          94        95       +1
  Lines        9007      9021      +14
=======================================
+ Hits         6510      6531      +21
+ Misses       2497      2490       -7
```
@abock I think onnxscript is not complaining about the typing annotations now, so we are good (hopefully?).
This change implements

```python
TTensor = TypeVar("TTensor", bound=_TensorType)
TFloat = TypeVar("TFloat", bound=_FloatType)
TInt = TypeVar("TInt", bound=_IntType)
TReal = TypeVar("TReal", bound=_RealType)
```

and uses them to annotate the functions. This information will be used to generate symbolic functions that add glue nodes and dispatch to the ONNX functions. In order for type checking to work, the generated ops need to use `TypeVar`s as well.
```diff
 # abs(Tensor self) -> Tensor
     return op.Abs(self)

 @torch_op("aten::acos")
-def aten_acos(self):
+def aten_acos(self: TFloat) -> TFloat:
```
In the current exporter implementation, the export still works even when the input is tensor(bool).
If we limit the input to TFloat like this, will it break export?
The function does not take bools because ONNX Acos does not. The current plan is to use the generated symbolic functions to handle dtype conversion.
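A rough sketch of that plan, with plain Python lists standing in for tensors. The function names and dtype strings below are illustrative, not the exporter's actual API:

```python
import math

def onnx_acos(values, dtype):
    # Models the ONNX Acos op, which only accepts float tensors.
    if dtype not in ("float16", "float32", "float64"):
        raise TypeError("Acos only supports float tensors")
    return [math.acos(v) for v in values]

def symbolic_acos(values, dtype):
    # Generated "glue": promote bool/int inputs to float32 before
    # dispatching to the dtype-restricted ONNX function.
    if dtype in ("bool", "int32", "int64"):
        values, dtype = [float(v) for v in values], "float32"
    return onnx_acos(values, dtype)

# bool input is accepted via the cast, even though Acos itself
# rejects it: acos(1.0) == 0.0, acos(0.0) == pi/2.
print(symbolic_acos([True, False], "bool"))
```

The ONNX function stays narrowly typed (matching the operator's spec), while the generated symbolic layer absorbs the dtype mismatches the exporter currently tolerates.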
Approved
Stack from ghstack (oldest at bottom):

This change implements the `TTensor`, `TFloat`, `TInt`, and `TReal` type variables and uses them to annotate the functions. This information will be used to generate symbolic functions that add glue nodes and dispatch to the ONNX functions. In order for type checking to work, the generated ops need to use `TypeVar`s as well.