Add Ops(_native_batch_norm_legit_functional) | feat(torchlib) #1143


Merged

Conversation

Contributor

@titaiwangms titaiwangms commented Nov 10, 2023

Fix #1140

Add:
(1) aten::_native_batch_norm_legit.no_stats
(2) aten::_to_copy
(3) aten::_native_batch_norm_legit_functional

aten::_native_batch_norm_legit_functional is only invoked by the Functionalization pass, so it can't be tested in op_test; it will be added to op_test on the converter side. The only difference between this op and aten::_native_batch_norm_legit is the number of outputs: aten::_native_batch_norm_legit_functional additionally returns running_mean and running_var, per https://github.com/pytorch/pytorch/blob/1488bafb274fcc82c8aac429bad61738bc3f950e/torch/_decomp/decompositions.py#L1804-L1826
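
A minimal sketch of the difference, following the decomposition linked above (tensor values are arbitrary):

import torch

x = torch.randn(2, 3, 4, 4)
weight, bias = torch.ones(3), torch.zeros(3)
running_mean, running_var = torch.zeros(3), torch.ones(3)

# The non-functional op returns three outputs and updates running_mean /
# running_var in place.
out, save_mean, save_rstd = torch.ops.aten._native_batch_norm_legit(
    x, weight, bias, running_mean, running_var, True, 0.1, 1e-5
)

# The functional variant instead returns five outputs: the same three plus
# the updated running_mean and running_var, leaving its inputs untouched.
# It is emitted by the Functionalization pass rather than called directly.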

aten_native_batch_norm_legit is split into two sample inputs that are fed separately into the different ONNX variants, since the variants require different sets of arguments; a sketch follows.
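
A minimal sketch of the split (names and shapes here are illustrative, not the exact extra_opinfo.py code):

import torch
from torch.testing import make_tensor

def sample_inputs_legit(device, dtype, requires_grad):
    x = make_tensor((2, 3, 4), device=device, dtype=dtype, requires_grad=requires_grad)
    weight = torch.ones(3, device=device, dtype=dtype)
    bias = torch.zeros(3, device=device, dtype=dtype)
    running_mean = torch.zeros(3, device=device, dtype=dtype)
    running_var = torch.ones(3, device=device, dtype=dtype)
    # The variant with running stats takes the full argument set.
    yield x, weight, bias, running_mean, running_var, True, 0.1, 1e-5

def sample_inputs_legit_no_stats(device, dtype, requires_grad):
    x = make_tensor((2, 3, 4), device=device, dtype=dtype, requires_grad=requires_grad)
    weight = torch.ones(3, device=device, dtype=dtype)
    bias = torch.zeros(3, device=device, dtype=dtype)
    # The .no_stats variant omits running_mean / running_var.
    yield x, weight, bias, True, 0.1, 1e-5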

@titaiwangms titaiwangms added the module: torchlib Related to the torch/aten function lib in development label Nov 10, 2023
@titaiwangms titaiwangms marked this pull request as draft November 10, 2023 01:42

codecov bot commented Nov 10, 2023

Codecov Report

Merging #1143 (4dc250f) into main (fdef96c) will decrease coverage by 0.01%.
The diff coverage is 78.78%.

@@            Coverage Diff             @@
##             main    #1143      +/-   ##
==========================================
- Coverage   78.41%   78.41%   -0.01%     
==========================================
  Files         118      118              
  Lines       15062    15125      +63     
  Branches     1607     1618      +11     
==========================================
+ Hits        11811    11860      +49     
- Misses       2880     2889       +9     
- Partials      371      376       +5     
Files                                                   Coverage           Δ
...nxscript/tests/function_libs/torch_lib/ops_test.py    94.81%  <ø>      (ø)
...ipt/tests/function_libs/torch_lib/ops_test_data.py    96.18%  <ø>      (ø)
...ript/tests/function_libs/torch_lib/extra_opinfo.py    97.06%  <90.90%> (-0.33%) ⬇️
onnxscript/function_libs/torch_lib/ops/core.py           79.66%  <72.72%> (-0.12%) ⬇️

@titaiwangms titaiwangms marked this pull request as ready for review November 10, 2023 01:49
@justinchuby justinchuby self-assigned this Nov 10, 2023
Collaborator

justinchuby commented Nov 10, 2023

Is it possible to create a custom op info for _native_batch_norm_legit_functional?

Contributor Author

Is it possible to create a custom op info for _native_batch_norm_legit_functional?

Because it's not in https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml, I thought it couldn't be invoked with torch.ops.aten. I added the tests.
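
For reference, a quick way to probe what torch.ops.aten exposes (a hedged sketch; what resolves depends on the PyTorch build):

import torch

# Ops registered in native_functions.yaml resolve as overload packets.
print(hasattr(torch.ops.aten, "_native_batch_norm_legit"))
# Whether the functional variant resolves here may differ, per the discussion above.
print(hasattr(torch.ops.aten, "_native_batch_norm_legit_functional"))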

Comment on lines +5579 to +5580
# NOTE: This op is invoked by PyTorch Functionalization and is not in
# native_functions.yaml. It can be found in torch/_decomp/decompositions.py.
Collaborator

Should we file an issue in pytorch to include it in the yaml?

Contributor Author

# Default bias: zeros expanded to the size of the channel dimension (dim 1).
bias = op.Expand(op.Constant(value_floats=[0.0]), op.Shape(input, start=1, end=2))

# Reduce over every dimension except dim 1, the channel dimension, since
# batch norm computes its statistics per channel.
axes = list(range(len(input.shape)))
axes.pop(1)
Collaborator

Possible to add a comment on why we pop index 1?

Contributor Author

Hmm, this code is duplicated from aten_native_batch_norm; the two differ only in the number of outputs. Maybe @xiaowuhu can chime in and answer this?

Contributor Author

It might be confusing that I am duplicating the code. Alternatively, we could add a higher-level traced function to merge these ops and use a functional: bool flag to differentiate them.
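
A minimal sketch of that alternative (every name below is hypothetical):

def _aten_native_batch_norm_common(
    input, weight, bias, running_mean, running_var,
    training: bool, momentum: float, eps: float, functional: bool = False,
):
    # _compute_batch_norm stands in for the shared normalization core; it
    # would return the normalized output, the saved statistics, and the
    # updated running statistics.
    norm, mean, var, new_mean, new_var = _compute_batch_norm(
        input, weight, bias, running_mean, running_var, training, momentum, eps
    )
    if functional:
        # Only the functional variant returns the updated running statistics.
        return norm, mean, var, new_mean, new_var
    return norm, mean, var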

Collaborator

Feel free to add a todo for now and create an issue to track

Collaborator

It might be confusing that I am duplicating the code. Alternatively, we could add a higher-level traced function to merge these ops and use a functional: bool flag to differentiate them.

I think it's fine for now.

# Cannot return the same value as two outputs, so build two separate empty
# tensors under different variable names. op.Shape(input, start=0, end=0)
# produces an empty tensor, which op.CastLike casts to norm's dtype.
empty_mean = op.CastLike(op.Shape(input, start=0, end=0), norm)
empty_var = op.CastLike(op.Shape(input, start=0, end=0), norm)
return norm, empty_mean, empty_var, running_mean, running_var
Collaborator

Add a note saying we omitted computing mean and var so readers know to implement them when needed?

Contributor Author

Done

Collaborator

@justinchuby justinchuby left a comment

LGTM thanks! Just a few final comments

@titaiwangms titaiwangms added the merge at lgtm Reviewers can merge when they approve label Nov 10, 2023
@justinchuby justinchuby merged commit 88ee668 into microsoft:main Nov 10, 2023
Successfully merging this pull request may close these issues:

[torchlib] native_batch_norm needs more aten ops support