
Add ops(unflatten) | feat(atenlib) #612


Merged: 8 commits, Apr 7, 2023

Conversation

titaiwangms (Contributor):

Fixes #599
xfail the graph test until microsoft/onnxruntime#15409 lands.

@titaiwangms added the module: torchlib label (Related to the torch/aten function lib in development) on Apr 7, 2023

codecov bot commented on Apr 7, 2023

Codecov Report

Merging #612 (ca464d2) into main (25027f9) will increase coverage by 0.03%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##             main     #612      +/-   ##
==========================================
+ Coverage   74.35%   74.38%   +0.03%     
==========================================
  Files         107      107              
  Lines       11209    11226      +17     
  Branches     1161     1162       +1     
==========================================
+ Hits         8334     8351      +17     
  Misses       2560     2560              
  Partials      315      315              
Impacted Files Coverage Δ
onnxscript/function_libs/torch_aten/ops/core.py 73.87% <100.00%> (+0.17%) ⬆️
...s/function_libs/torch_aten/ops_correctness_test.py 88.43% <100.00%> (+0.08%) ⬆️


dim = self_rank + dim

head_start_idx = op.Constant(value_ints=[0])
head_end_idx = op.Reshape(dim, op.Constant(value_ints=[1]))

titaiwangms (Contributor, Author):
@gramalingam @justinchuby
ORTEvaluator doesn't seem to understand:

head_end_idx = op.Reshape(dim, [1])
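
A minimal sketch of the workaround, restating the two forms from the excerpt above: wrap the target shape in an explicit op.Constant so ORTEvaluator receives a tensor input rather than a Python list literal.

# Rejected by ORTEvaluator: shape given as a plain Python list
# head_end_idx = op.Reshape(dim, [1])

# Accepted: shape given as a constant INT64 tensor
head_end_idx = op.Reshape(dim, op.Constant(value_ints=[1]))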

Comment on lines 1079 to 1083
skip(
"unflatten",
matcher=lambda sample: any(dim == 0 for dim in sample.input.shape),
reason="0 dim in ONNX is undefined behavior.",
),

Collaborator:

Can we create a condition to handle this case?

titaiwangms (Contributor, Author):

To raise a warning? Not sure what we should give back to users.

justinchuby (Collaborator) commented on Apr 7, 2023:

What is PyTorch's behavior?

> To raise a warning? Not sure what we should give back to users.

I was suggesting handling it like PyTorch does, if possible.

titaiwangms (Contributor, Author):

PyTorch accepts this because zero-size dimensions are part of its spec. For example:

a = torch.randn(4, 0)
# tensor([], size=(4, 0))

torch.unflatten(a, 0, (2, 2))
# tensor([], size=(2, 2, 0))
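
For contrast, the ordinary case with non-zero sizes simply splits the chosen dimension:

b = torch.randn(3, 4)

torch.unflatten(b, 1, (2, 2)).shape
# torch.Size([3, 2, 2])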

titaiwangms (Contributor, Author) commented on Apr 7, 2023:

I wonder if torch_libs has an established way to inform users when a difference between the specs is hit? In the TorchScript exporter, we could raise UnsupportedONNXError.

justinchuby (Collaborator) commented on Apr 7, 2023:

I see. We could skip this if it is too complicated and come back to it as needed. I would update the skip reason to say something to the effect of "fixme: logic not implemented for size 0 inputs"
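
For concreteness, a sketch of the skip entry with that updated reason (same matcher as the block quoted above):

skip(
    "unflatten",
    matcher=lambda sample: any(dim == 0 for dim in sample.input.shape),
    reason="fixme: logic not implemented for size 0 inputs",
),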

Collaborator:

We don't raise errors in atenlib. The exporter can raise errors when it sees a difference after OpSchema is implemented

@@ -100,9 +100,6 @@ def aten_addmm(
) -> TFloat:
"""addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"""

# TODO(titaiwang): op.Gemm seems needed to take care of corner case according to old symbolic_fn.
# Currently, it shows op-level validation failing on bloom.

Collaborator:

Remove line 110 down below as well?
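
For context on the removed TODO: ONNX Gemm computes alpha * (A @ B) + beta * C and requires A and B to be 2-D, and the old symbolic_fn the TODO cites used Gemm for addmm. A minimal sketch of such a Gemm-based variant (illustrative only, not this PR's code; op and TFloat are the names used in core.py above, and the function name here is hypothetical):

def aten_addmm_via_gemm(self: TFloat, mat1: TFloat, mat2: TFloat, beta: float = 1.0, alpha: float = 1.0) -> TFloat:
    # ONNX Gemm fuses the matmul and the scaled add: alpha * (mat1 @ mat2) + beta * self.
    # Note that Gemm only accepts 2-D inputs for mat1 and mat2.
    return op.Gemm(mat1, mat2, self, alpha=alpha, beta=beta)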

Labels
module: torchlib Related to the torch/aten function lib in development
Development

Successfully merging this pull request may close these issues.

Support torch.unflatten
2 participants