Add ops(unflatten) | feat(atenlib) #612
Codecov Report
@@            Coverage Diff             @@
##             main     #612      +/-   ##
==========================================
+ Coverage   74.35%   74.38%   +0.03%
==========================================
  Files         107      107
  Lines       11209    11226      +17
  Branches     1161     1162       +1
==========================================
+ Hits         8334     8351      +17
  Misses       2560     2560
  Partials      315      315
==========================================
    dim = self_rank + dim

    head_start_idx = op.Constant(value_ints=[0])
    head_end_idx = op.Reshape(dim, op.Constant(value_ints=[1]))
@gramalingam @justinchuby
ORTEvaluator doesn't seem to understand:

    head_end_idx = op.Reshape(dim, [1])
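A plausible reading of why the wrapping matters: in eager mode the evaluator appears to require every `Reshape` input, including the target shape, to be an actual tensor rather than a Python list, hence `op.Constant(value_ints=[1])`. The index arithmetic itself is simple; a pure-Python analogue (illustrative names, not the onnxscript API):

```python
# Hypothetical pure-Python stand-ins for the graph ops above
# (this is NOT the onnxscript API, just the index arithmetic):
def constant(value_ints):
    # stands in for op.Constant(value_ints=[...]): a 1-D int64 tensor
    return list(value_ints)

def reshape_to_1d(scalar):
    # stands in for op.Reshape(dim, op.Constant(value_ints=[1])):
    # promote the 0-d `dim` to a 1-D tensor so Slice accepts it
    return [scalar]

self_rank = 3
dim = -2
dim = self_rank + dim                  # normalize the negative dim -> 1
head_start_idx = constant([0])         # [0]
head_end_idx = reshape_to_1d(dim)      # [1]

shape = (2, 12, 5)
# analogue of op.Slice(shape, head_start_idx, head_end_idx)
head_dims = shape[head_start_idx[0]:head_end_idx[0]]
print(head_dims)  # (2,)
```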
    skip(
        "unflatten",
        matcher=lambda sample: any(dim == 0 for dim in sample.input.shape),
        reason="0 dim in ONNX is undefined behavior.",
    ),
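The matcher above can be exercised on its own; `Tensor` and `Sample` here are hypothetical stand-ins for the OpInfo sample objects, not the test framework's actual types:

```python
from dataclasses import dataclass
from typing import Tuple

@dataclass
class Tensor:
    shape: Tuple[int, ...]

@dataclass
class Sample:
    input: Tensor

# same predicate as the skip() above: skip any sample with a 0-sized dim
matcher = lambda sample: any(dim == 0 for dim in sample.input.shape)

print(matcher(Sample(Tensor((4, 0)))))   # True  -> test is skipped
print(matcher(Sample(Tensor((2, 12)))))  # False -> test runs
```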
Can we create a condition to handle this case?
To raise a warning? Not sure what we should give back to users.
What is pytorch's behavior?
> To raise a warning? Not sure what we should give back to users.
I was suggesting handling it like pytorch does if possible
PyTorch accepts this because zero-sized dimensions are part of its spec.
For example,
a = torch.randn(4, 0)
# tensor([], size=(4, 0))
torch.unflatten(a, 0, (2, 2))
# tensor([], size=(2, 2, 0))
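The shape arithmetic behind that example: unflatten splits `shape[dim]` into `sizes`, and a zero-sized dimension elsewhere simply survives. A minimal sketch of the semantics (an illustrative helper, not PyTorch's implementation):

```python
import math

def unflatten_shape(shape, dim, sizes):
    """Shape that torch.unflatten(x, dim, sizes) would produce (sketch)."""
    dim = dim % len(shape)                 # handle negative dims
    if math.prod(sizes) != shape[dim]:
        raise ValueError("sizes must multiply to shape[dim]")
    return shape[:dim] + tuple(sizes) + shape[dim + 1:]

print(unflatten_shape((4, 0), 0, (2, 2)))    # (2, 2, 0), matching torch
print(unflatten_shape((2, 12, 5), 1, (3, 4)))  # (2, 3, 4, 5)
```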
I wonder if torch_libs has a way to inform users when a difference between the specs is hit? In the TorchScript exporter, we could raise UnsupportedONNXError.
I see. We could skip this if it is too complicated and come back to it as needed. I would update the skip reason to say something to the effect of "fixme: logic not implemented for size 0 inputs"
We don't raise errors in atenlib. The exporter can raise errors when it sees a difference after OpSchema is implemented
@@ -100,9 +100,6 @@ def aten_addmm(
 ) -> TFloat:
     """addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"""
-
-    # TODO(titaiwang): op.Gemm seems needed to take care of corner case according to old symbolic_fn.
-    # Currently, it shows op-level validation failing on bloom.
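For context, `addmm` computes `beta * self + alpha * (mat1 @ mat2)`, which is the fused computation ONNX `Gemm` provides. A pure-Python sketch of the semantics on nested lists (illustrative only, not the torchlib implementation):

```python
def addmm(self_, mat1, mat2, beta=1, alpha=1):
    """beta * self_ + alpha * (mat1 @ mat2), on nested lists (sketch)."""
    rows, inner, cols = len(mat1), len(mat2), len(mat2[0])
    # plain matrix multiply mat1 @ mat2
    prod = [[sum(mat1[i][k] * mat2[k][j] for k in range(inner))
             for j in range(cols)] for i in range(rows)]
    # scale and add the bias matrix
    return [[beta * self_[i][j] + alpha * prod[i][j]
             for j in range(cols)] for i in range(rows)]

print(addmm([[0, 0], [0, 0]], [[1, 2], [3, 4]], [[5, 6], [7, 8]]))
# [[19, 22], [43, 50]]
```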
Remove line 110 below as well?
…nnx-script into titaiwang/add_unflatten
Fixes #599
xfail graph test to wait for microsoft/onnxruntime#15409