AddOp(linalg_vector_norm) | feat(torchlib) #908
Conversation
Codecov Report
@@            Coverage Diff             @@
##             main     #908      +/-   ##
==========================================
- Coverage   76.74%   76.62%   -0.13%
==========================================
  Files         112      112
  Lines       13486    13547      +61
  Branches     1363     1377      +14
==========================================
+ Hits        10350    10380      +30
- Misses       2796     2823      +27
- Partials      340      344       +4
).skip(
    matcher=lambda sample: sample.kwargs.get("ord") == 6
    and sample.input.dtype == torch.float16,
    reason="ORT return wrong value for float16 with ord=6 (expected=Inf, actual=9.48).",
Looks like the number we get is actually more accurate. Is that true?
We can potentially make it an xfail here, so we know when things get fixed (a rough sketch of the xfail idea follows this thread).
If ORT is more accurate, we can change the wording "wrong value".
Done
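Below is a hypothetical, self-contained sketch of what the xfail idea means, written in plain pytest terms rather than the repository's OpInfo-based test table (whose .skip(...) entry is quoted above); the tensor values merely stand in for the mismatch reported in the skip reason. Unlike a skip, an xfail-ed test still runs, so it surfaces as an unexpected pass once the ORT discrepancy is fixed.

```python
import pytest
import torch


@pytest.mark.xfail(reason="ORT returns a finite value for float16 with ord=6; PyTorch returns inf")
def test_vector_norm_float16_ord6_placeholder():
    expected = torch.tensor(float("inf"), dtype=torch.float16)  # PyTorch result per the thread
    actual = torch.tensor(9.48, dtype=torch.float16)            # ORT result per the skip reason
    # Fails today and is reported as xfailed; once ORT matches PyTorch it shows up as XPASS.
    torch.testing.assert_close(actual, expected)
```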
"""linalg_vector_norm(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor""" | ||
|
||
raise NotImplementedError() | ||
if dtype != -1: |
@justinchuby I am confused. Do we want the dispatcher to take care of the dtype overload?
Should the dtype scenario follow aten_argmax or not?
We do. But there's also dim: Optional[int] = None, which can also be split, in which case we would end up with 4 variants of the same function. Since the function is already trace_only for reasons other than the dtype situation here, I thought it wasn't worth having 4 copies of this thing (a rough sketch of that shape follows below).
I took the liberty to merge. Happy to follow up if there's more discussion desired.
Also updated the should_skip_xfail logic in the test to account for data types.
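For context on the trace_only point above, here is a rough, hedged sketch (not the merged implementation, which is registered through the library's torch_op decorator and also handles ord=0 and infinite ord): because the function is trace_only, dtype and dim are ordinary Python values at trace time, so one function can cover what would otherwise be four registered variants.

```python
from typing import Optional, Sequence

from onnxscript import opset18 as op


def linalg_vector_norm_sketch(
    self,                        # tensor (e.g. a numpy array in eager mode)
    ord: float = 2.0,            # only the generic finite, nonzero ord is handled here
    dim: Optional[Sequence[int]] = None,
    keepdim: bool = False,
    dtype: int = -1,             # torchlib convention: -1 means "dtype not provided"
):
    if dtype != -1:
        self = op.Cast(self, to=dtype)  # dtype overload resolved with plain Python branching
    if dim is None:
        # No dim given: take the norm over all elements by flattening first.
        self = op.Reshape(self, op.Constant(value_ints=[-1]))
        dim = [0]
    exponent = op.CastLike(op.Constant(value_float=float(ord)), self)
    powered = op.Pow(op.Abs(self), exponent)
    summed = op.ReduceSum(powered, op.Constant(value_ints=list(dim)), keepdims=int(keepdim))
    return op.Pow(summed, op.CastLike(op.Constant(value_float=1.0 / float(ord)), summed))
```

In eager mode the same function can be called directly with numpy inputs, which is how eager tests would exercise it; in FullGraph mode the op.* calls are traced into the ONNX graph instead.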
Notes
Have to skip one kind of test case: ord=6, dtype=float16. In PyTorch the second element of the result is inf, but in ORT it comes out as 9.08; it should be inf. It works in eager mode and fails only in FullGraph mode.
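To make the note above concrete, here is a hypothetical, self-contained repro sketch of the likely mechanism (the actual OpInfo sample values are not shown in this thread, so the numbers below are made up): float16 tops out at 65504, so any element with magnitude above roughly 6.36 makes x**6 overflow, and a norm computed entirely in float16 collapses to inf, while computing in float32 and casting back stays finite.

```python
import torch

x = torch.tensor([1.0, 9.0, 2.0], dtype=torch.float16)  # hypothetical values; 9**6 = 531441 > 65504

# Entirely in float16: the pow overflows, so the whole norm becomes inf.
norm_half = x.abs().pow(6).sum().pow(1.0 / 6)

# In float32 with a final cast back: stays finite (around 9 for these made-up inputs).
norm_float = x.float().abs().pow(6).sum().pow(1.0 / 6).half()

print(norm_half, norm_float)  # inf vs a finite value
```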