feat(atenlib): use typevars for functions #261


Merged: 5 commits merged into main on Jan 3, 2023

Conversation

@justinchuby (Collaborator) commented Dec 16, 2022

Stack from ghstack (oldest at bottom):

This change implements

```python
TTensor = TypeVar("TTensor", bound=_TensorType)
TFloat = TypeVar("TFloat", bound=_FloatType)
TInt = TypeVar("TInt", bound=_IntType)
TReal = TypeVar("TReal", bound=_RealType)
```

and uses them to annotate the functions. This information will be used to generate symbolic functions that add glue nodes and dispatch to the ONNX functions.

In order for type checking to work, the generated ops need to use `TypeVar`s as well (see the sketch below).

  • Also fix mypy ignore directives
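
As an illustration of the pattern, here is a minimal, self-contained sketch. The plain classes and the `_FloatType` alias below are stand-ins for the actual onnxscript tensor types and the union aliases in `torch_aten/typing.py`; only the `TypeVar` mechanics carry over:

```python
# Minimal sketch of the bounded-TypeVar pattern; the classes are
# stand-ins for the real onnxscript tensor types.
from typing import TypeVar, Union


class FLOAT: ...
class FLOAT16: ...
class DOUBLE: ...


_FloatType = Union[FLOAT16, FLOAT, DOUBLE]
TFloat = TypeVar("TFloat", bound=_FloatType)


def aten_acos(self: TFloat) -> TFloat:
    # Returning the input keeps the sketch runnable; the real function
    # calls the ONNX Acos op. Because the same TypeVar appears in the
    # parameter and the return annotation, a checker such as mypy infers
    # that a FLOAT input yields a FLOAT output, not just _FloatType.
    return self


y = aten_acos(FLOAT())  # inferred as FLOAT, not the union
```

This is what a plain `Union` annotation cannot express: with a shared `TypeVar`, the output dtype is tied to the input dtype.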

justinchuby added a commit that referenced this pull request Dec 16, 2022
ghstack-source-id: 5779ed0
Pull Request resolved: #261
@justinchuby requested a review from fatcat-z December 16, 2022 01:05
@justinchuby changed the title from "feat(atenlib): use generic types for functions" to "feat(atenlib): use typevars for functions" Dec 16, 2022
@justinchuby changed the base branch from gh/justinchuby/16/base to main December 16, 2022 01:05
@justinchuby added the module: torchlib label (Related to the torch/aten function lib in development) Dec 16, 2022

codecov bot commented Dec 16, 2022

Codecov Report

Merging #261 (2412510) into main (b3ec873) will increase coverage by 0.12%.
The diff coverage is 94.52%.

```diff
@@            Coverage Diff             @@
##             main     #261      +/-   ##
==========================================
+ Coverage   72.27%   72.39%   +0.12%
==========================================
  Files          94       95       +1
  Lines        9007     9021      +14
==========================================
+ Hits         6510     6531      +21
+ Misses       2497     2490       -7
```
| Impacted Files | Coverage Δ |
| --- | --- |
| onnxscript/irbuilder.py | 79.92% <ø> (ø) |
| onnxscript/function_libs/torch_aten/ops/nn.py | 52.47% <80.00%> (+0.18%) ⬆️ |
| onnxscript/function_libs/torch_aten/ops/core.py | 56.09% <94.91%> (+0.03%) ⬆️ |
| onnxscript/function_libs/torch_aten/typing.py | 100.00% <100.00%> (ø) |
| onnxscript/converter.py | 92.03% <0.00%> (+0.80%) ⬆️ |
| onnxscript/type_annotation.py | 94.11% <0.00%> (+2.94%) ⬆️ |


@justinchuby requested a review from abock December 16, 2022 01:09
@justinchuby (Collaborator, Author) commented:

@abock I think onnxscript is not complaining about the typing annotations now so we are good (hopefully?).

justinchuby added a commit that referenced this pull request Jan 3, 2023
ghstack-source-id: f1123c5
Pull Request resolved: #261
Review comment on the `aten_acos` change:

```diff
     # abs(Tensor self) -> Tensor

     return op.Abs(self)


 @torch_op("aten::acos")
-def aten_acos(self):
+def aten_acos(self: TFloat) -> TFloat:
```
Contributor commented:
In the current exporter implementation, the export still works even when the input is tensor(bool).
If we limit the input to TFloat like this, will it break export?

@justinchuby (Collaborator, Author) replied:

The function does not take bools because ONNX Acos does not. The current plan is to use the generated symbolic functions to handle dtype conversion.
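
As a purely hypothetical sketch of that plan (the helper, the node-building style, and the `onnxscript` domain are all assumptions, not the actual generator output), a generated symbolic function could insert a `Cast` glue node before dispatching to the ONNX function:

```python
# Hypothetical sketch of glue logic a generated symbolic function might
# add; helper names and the "onnxscript" domain are assumptions.
import onnx
from onnx import helper


def symbolic_acos(input_name: str, input_dtype: int, nodes: list) -> str:
    """Appends nodes for aten::acos, casting dtypes that Acos rejects."""
    if input_dtype == onnx.TensorProto.BOOL:
        # Glue node: ONNX Acos has no bool overload, so cast to float.
        cast_name = input_name + "_cast"
        nodes.append(
            helper.make_node(
                "Cast", [input_name], [cast_name], to=onnx.TensorProto.FLOAT
            )
        )
        input_name = cast_name
    # Dispatch to the ONNX function generated from aten_acos.
    out_name = input_name + "_acos"
    nodes.append(
        helper.make_node("aten_acos", [input_name], [out_name], domain="onnxscript")
    )
    return out_name
```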

@fatcat-z (Contributor) left a review: Approved

@justinchuby merged commit 078471e into main Jan 3, 2023
@justinchuby deleted the gh/justinchuby/16/head branch January 3, 2023 15:56