aten::var implementation and aten::roll complex support #1186


Merged (12 commits) on Nov 30, 2023

Conversation

@luisfmnunes (Contributor) commented Nov 27, 2023

As mentioned in #1173, I'm adding aten::var and aten::roll (complex support) so I can export a model from PyTorch to ONNX. The model uses FFT functions, which require opset 18 and torch dynamo.

fixes #1175 fixes #1174
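For context on the var side (issue #1175), aten.var.correction divides the sum of squared deviations by N - correction rather than N. A minimal NumPy sketch of those semantics; this is illustrative only, not the torchlib implementation, and var_with_correction is a hypothetical name:

```python
import numpy as np

def var_with_correction(x, dim=None, correction=1, keepdim=False):
    """Variance with divisor (N - correction), matching torch.var semantics.

    correction=1 is the unbiased estimator (torch's default); correction=0
    is the biased one. For brevity, `dim` here is a single axis or None.
    """
    x = np.asarray(x, dtype=np.float64)
    mean = np.mean(x, axis=dim, keepdims=True)
    sq = (x - mean) ** 2
    n = x.size if dim is None else x.shape[dim]
    return np.sum(sq, axis=dim, keepdims=keepdim) / (n - correction)

data = [1.0, 2.0, 3.0, 4.0]
# Cross-check: NumPy's ddof plays the same role as torch's correction.
assert np.isclose(var_with_correction(data, correction=1), np.var(data, ddof=1))
assert np.isclose(var_with_correction(data, correction=0), np.var(data, ddof=0))
```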

codecov bot commented Nov 27, 2023

Codecov Report

Attention: 3 lines in your changes are missing coverage. Please review.

Comparison: base (5ba7efa) 78.56% vs. head (aa82989) 78.63%.

Files                                          | Patch % | Lines
onnxscript/function_libs/torch_lib/ops/core.py | 95.45%  | 1 Missing and 2 partials ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1186      +/-   ##
==========================================
+ Coverage   78.56%   78.63%   +0.06%     
==========================================
  Files         118      118              
  Lines       15380    15445      +65     
  Branches     2408     2424      +16     
==========================================
+ Hits        12084    12145      +61     
- Misses       2900     2902       +2     
- Partials      396      398       +2     


@justinchuby added the module: torchlib label (Related to the torch/aten function lib in development) on Nov 27, 2023
@luisfmnunes (Contributor, Author) commented:

@microsoft-github-policy-service agree

@justinchuby (Collaborator) left a comment


Thanks a lot!

@justinchuby justinchuby self-requested a review November 29, 2023 00:14
Comment on lines 6948 to 6954
for i in range(len(shifts)): # pylint: disable=consider-using-enumerate
shift = op.Gather(shifts, i, axis=0)
dim = dims[i]
self_real = _aten_roll_shift_and_dim_onnx(self_real, shift, dim)
self_imag = _aten_roll_shift_and_dim_onnx(self_imag, shift, dim)

result = op.Concat(self_real, self_imag, axis=-1)
@justinchuby (Collaborator) commented Nov 29, 2023


Since shifts is a tensor, len() is not defined on it. Iterate on dims instead?

for i, dim in enumerate(dims):
   ...
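A quick NumPy check of the underlying idea: torchlib carries complex tensors as real tensors with a trailing real/imag axis, and rolling the real and imaginary planes separately along the data dimensions matches rolling the complex tensor. This is a sketch under that assumption, not the torchlib code:

```python
import numpy as np

def roll_complex_as_real(x, shifts, dims):
    """Roll a complex-as-real tensor (trailing axis of size 2: real, imag).

    Assumes `dims` never includes the trailing real/imag axis.
    """
    real, imag = x[..., 0], x[..., 1]
    for shift, dim in zip(shifts, dims):  # iterate on dims, as suggested
        real = np.roll(real, shift, axis=dim)
        imag = np.roll(imag, shift, axis=dim)
    return np.stack([real, imag], axis=-1)

c = np.arange(6).reshape(2, 3) + 1j * np.arange(6).reshape(2, 3)
as_real = np.stack([c.real, c.imag], axis=-1)   # shape (2, 3, 2)
rolled = roll_complex_as_real(as_real, shifts=[1], dims=[1])
# Rolling planes separately equals rolling the complex tensor directly.
assert np.allclose(rolled[..., 0] + 1j * rolled[..., 1], np.roll(c, 1, axis=1))
```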

) -> TReal:
"""var(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor"""

if isinstance(dim, int):
Contributor


Could dim be something other than an int here? Should we annotate it in the function signature?

Collaborator


I can follow up on this after #1192. I believe this code was just copied from var_mean.
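For context, the isinstance(dim, int) branch exists because dim may arrive as a single int, a sequence of ints, or None. A hypothetical normalization helper sketching that handling; the name and signature are illustrative, not the torchlib API:

```python
from typing import Optional, Sequence, Union

def normalize_dims(dim: Optional[Union[int, Sequence[int]]]) -> Optional[list]:
    """Normalize `dim` to a list of ints, or None meaning 'all dimensions'."""
    if dim is None:
        return None          # reduce over all dimensions
    if isinstance(dim, int):
        return [dim]         # single axis becomes a one-element list
    return list(dim)         # already a sequence of axes

assert normalize_dims(None) is None
assert normalize_dims(2) == [2]
assert normalize_dims((0, 1)) == [0, 1]
```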

@justinchuby justinchuby merged commit 82d2063 into microsoft:main Nov 30, 2023
@justinchuby (Collaborator) commented:

Will fix the dtype error in a separate PR.

justinchuby added a commit that referenced this pull request Nov 30, 2023
Follow up of #1186 to fix the dtype mismatch in Sub in the
`var`/`var.correction` implementation.
Labels
module: torchlib (Related to the torch/aten function lib in development)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Implement aten.roll for complex inputs
Implement aten.var.correction
3 participants