
AddOp(linalg_vector_norm) | feat(torchlib) #908


Merged
7 commits merged into main on Jul 25, 2023

Conversation

xiaowuhu (Contributor) commented on Jul 23, 2023

Also updated the should_skip_xfail logic in the tests to account for data types.

Notes

Have to skip one kind of test case: ord=6, dtype=float16.
In PyTorch:

>>> b
tensor([[2.3730, 0.9316, 0.6240, 6.1523, 0.1758],
        [5.3984, 7.9375, 4.9062, 1.8809, 6.1016],
        [4.0234, 8.2344, 3.2695, 0.8701, 1.3447]], dtype=torch.float16)
>>> la.vector_norm(a, dim=0, ord=6)
tensor([5.5483, 9.0838, 4.9754, 6.1532, 6.1017])
>>> la.vector_norm(b, dim=0, ord=6)
tensor([5.5469,    inf, 4.9727, 6.1523, 6.0977], dtype=torch.float16)
>>>

But in ORT, the result is:

tensor([5.5469,  9.08, 4.9727, 6.1523, 6.0977], dtype=torch.float16)

The second element, 9.08, should be inf. It works in eager mode and fails only in FullGraph mode.
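
A quick way to see why float16 behaves differently here (an illustrative sketch using the second column of b above, not part of the PR): with ord=6 the intermediate |x|**6 terms exceed the float16 maximum (about 65504), so a pure-float16 reduction overflows to inf, while computing in float32 gives the ~9.08 value.

import torch

# Second column of b from the example above.
col16 = torch.tensor([0.9316, 7.9375, 8.2344], dtype=torch.float16)

# 8.2344**6 is roughly 3.1e5 > 65504 (the float16 max), so the float16 reduction overflows.
print(torch.linalg.vector_norm(col16, ord=6))          # tensor(inf, dtype=torch.float16)

# Upcasting first avoids the intermediate overflow and matches the float32 result.
print(torch.linalg.vector_norm(col16.float(), ord=6))  # tensor(9.0838)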

codecov bot commented on Jul 23, 2023

Codecov Report

Merging #908 (4ca6e8d) into main (f51abd2) will decrease coverage by 0.13%.
The diff coverage is 90.47%.

@@            Coverage Diff             @@
##             main     #908      +/-   ##
==========================================
- Coverage   76.74%   76.62%   -0.13%     
==========================================
  Files         112      112              
  Lines       13486    13547      +61     
  Branches     1363     1377      +14     
==========================================
+ Hits        10350    10380      +30     
- Misses       2796     2823      +27     
- Partials      340      344       +4     
Files Changed                                             Coverage Δ
onnxscript/function_libs/torch_lib/ops/linalg.py          67.14% <89.09%> (+14.81%) ⬆️
...nxscript/tests/function_libs/torch_lib/ops_test.py     94.57% <100.00%> (+0.08%) ⬆️
...ipt/tests/function_libs/torch_lib/ops_test_data.py     96.87% <100.00%> (+0.07%) ⬆️

... and 7 files with indirect coverage changes

@xiaowuhu xiaowuhu mentioned this pull request Jul 23, 2023
@titaiwangms titaiwangms added the module: torchlib Related to the torch/aten function lib in development label Jul 23, 2023
Review thread on the new skip entry in the test data:

).skip(
matcher=lambda sample: sample.kwargs.get("ord") == 6
and sample.input.dtype == torch.float16,
reason="ORT return wrong value for float16 with ord=6 (expected=Inf, actual=9.48).",
Collaborator:

Looks like the number we get is actually more accurate. Is that true?

Collaborator:

We can potentially make it an xfail here, so we know when things get fixed.
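
For reference, a possible xfail form of the same entry might look like this (a sketch only, assuming the test-data helper exposes .xfail with the same matcher/reason signature as the .skip above):

).xfail(
    matcher=lambda sample: sample.kwargs.get("ord") == 6
    and sample.input.dtype == torch.float16,
    reason="ORT and PyTorch disagree for float16 with ord=6 (PyTorch overflows to inf).",
)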

Contributor:

If ORT is more accurate, we can change the wording "wrong value".

Collaborator:

Done

"""linalg_vector_norm(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"""

raise NotImplementedError()
if dtype != -1:
Contributor:

@justinchuby I am confused. Do we want the dispatcher to take care of the dtype overload?

Should the dtype scenario follow aten_argmax or not?

Collaborator:

We do. But there's also dim: Optional[int] = None, which can also be split, in which case we would have 4 variants of the same function. Since the function is currently trace_only for reasons other than the dtype situation here, I thought it's not worth having 4 copies of this thing.
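
For illustration, a minimal sketch of keeping a single trace_only function and handling the dtype overload with an up-front cast (the signature defaults and the opset alias are assumptions following common torch_lib conventions, not the merged code):

from onnxscript import opset18 as op

def aten_linalg_vector_norm(self, ord=2.0, dim=None, keepdim=False, dtype=-1):
    # -1 is the sentinel for "dtype not provided", as in the diff above.
    if dtype != -1:
        self = op.Cast(self, to=dtype)
    ...  # ord/dim/keepdim handling continues as in the PR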

@justinchuby justinchuby changed the title AddOp(linalg vector norm) | feat(torchlib) AddOp(linalg_vector_norm) | feat(torchlib) Jul 25, 2023
@justinchuby justinchuby merged commit 241260f into main Jul 25, 2023
@justinchuby justinchuby deleted the xiaowu/AddOp(linalg_vector_norm) branch July 25, 2023 18:39
@justinchuby (Collaborator)

I took the liberty of merging. Happy to follow up if more discussion is desired.
