Add Ops(_native_batch_norm_legit_functional) | feat(torchlib) #1143
Conversation
Codecov Report

```
@@            Coverage Diff             @@
##             main    #1143      +/-   ##
==========================================
- Coverage   78.41%   78.41%   -0.01%
==========================================
  Files         118      118
  Lines       15062    15125      +63
  Branches     1607     1618      +11
==========================================
+ Hits        11811    11860      +49
- Misses       2880     2889       +9
- Partials      371      376       +5
```
Is it possible to create a custom op info for _native_batch_norm_legit_functional?
Because it's not in https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml, I thought it couldn't be invoked with torch.ops.aten. I added the tests.
```python
# NOTE: This op is invoked by PyTorch Functionalization and is not in
# native_functions.yaml. It can be found in torch/_decomp/decompositions.py.
```
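As context, here is a minimal sketch (not part of this PR) of how the op surfaces: tracing batch norm with running stats through functionalization rewrites the in-place stat updates into a functional variant. The APIs below are standard PyTorch, but the exact op that appears in the graph may vary by torch version.

```python
import torch
from torch.func import functionalize
from torch.fx.experimental.proxy_tensor import make_fx

def f(x, weight, bias, running_mean, running_var):
    # training=True mutates running_mean/running_var in place,
    # which functionalization must rewrite into a functional op
    return torch.native_batch_norm(
        x, weight, bias, running_mean, running_var, True, 0.1, 1e-5
    )

x = torch.randn(2, 3, 4, 4)
gm = make_fx(functionalize(f, remove="mutations"))(
    x, torch.ones(3), torch.zeros(3), torch.zeros(3), torch.ones(3)
)
print(gm.graph)  # expect aten._native_batch_norm_legit_functional here
```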
Should we file an issue in pytorch to include it in the yaml?
```python
bias = op.Expand(op.Constant(value_floats=[0.0]), op.Shape(input, start=1, end=2))
```

```python
axes = list(range(len(input.shape)))
axes.pop(1)
```
Possible to add a comment on why we pop index 1?
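For reference, a minimal sketch of what popping index 1 means here, assuming NCHW layout (my illustration, not code from this PR): batch norm reduces over every dimension except the channel dimension, yielding one mean/variance per channel.

```python
import torch

x = torch.randn(2, 3, 4, 5)       # NCHW input
axes = list(range(x.ndim))        # [0, 1, 2, 3]
axes.pop(1)                       # [0, 2, 3]: keep the channel dim
mean = x.mean(dim=axes)           # shape (3,): one value per channel
var = x.var(dim=axes, unbiased=False)
print(mean.shape, var.shape)      # torch.Size([3]) torch.Size([3])
```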
Hmm, this code is duplicated from aten_native_batch_norm; the two only differ in the number of outputs. Maybe @xiaowuhu can chime in and answer this?
It might be confusing that I am duplicating the code. Alternatively, we could add a higher-level traced function to merge these ops and use a `functional: bool` flag to differentiate them.
Feel free to add a TODO for now and create an issue to track it.
> It might be confusing that I am duplicating the code. Alternatively, we could add a higher-level traced function to merge these ops and use a `functional: bool` flag to differentiate them.
I think it's fine for now.
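A hedged sketch of the merge idea being discussed, with a `functional` flag controlling the output arity; all names here are illustrative assumptions, not code from this PR.

```python
import torch

def batch_norm_variant(x, running_mean, running_var, functional: bool):
    # illustrative stand-in, not the torchlib implementation
    axes = list(range(x.ndim))
    axes.pop(1)  # reduce over all dims except the channel dim
    mean = x.mean(dim=axes)
    var = x.var(dim=axes, unbiased=False)
    shape = [1, -1] + [1] * (x.ndim - 2)  # broadcast per-channel stats
    norm = (x - mean.reshape(shape)) / torch.sqrt(var.reshape(shape) + 1e-5)
    if functional:
        # the functional variant additionally returns the running stats
        return norm, mean, var, running_mean, running_var
    return norm, mean, var
```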
```python
# Cannot return the same output twice, so compute the empty tensor
# under two different variable names
empty_mean = op.CastLike(op.Shape(input, start=0, end=0), norm)
empty_var = op.CastLike(op.Shape(input, start=0, end=0), norm)
return norm, empty_mean, empty_var, running_mean, running_var
```
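Why this produces an empty tensor, shown with NumPy as a stand-in (my illustration): slicing the shape vector from 0 to 0 yields a zero-length 1-D tensor, and CastLike then matches norm's dtype.

```python
import numpy as np

shape = np.array([2, 3, 4, 4], dtype=np.int64)  # result of Shape(input)
empty = shape[0:0]                              # start=0, end=0 -> length-0 slice
print(empty.shape, empty.dtype)                 # (0,) int64
```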
Add a note saying we omitted computing mean and var so readers know to implement them when needed?
Done
LGTM thanks! Just a few final comments
Fix #1140
Add:
(1) `aten::_native_batch_norm_legit.no_stats`
(2) `aten::_to_copy`
(3) `aten::_native_batch_norm_legit_functional`

`aten::_native_batch_norm_legit_functional` is only invoked by the Functionalization pass, so it can't be tested in op_test; it will be added to op_test on the converter side. The only difference between this op and `aten::_native_batch_norm_legit` is the number of outputs: `aten::_native_batch_norm_legit_functional` additionally returns running_mean and running_var, per https://github.com/pytorch/pytorch/blob/1488bafb274fcc82c8aac429bad61738bc3f950e/torch/_decomp/decompositions.py#L1804-L1826.

`aten_native_batch_norm_legit` is split into two sample inputs that are fed separately into different ONNX variants, since they require different sets of arguments.
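A hedged sketch of the output-arity difference, calling the linked decomposition eagerly; this assumes the function name and signature shown in the linked file and may need adjusting for your torch version.

```python
import torch
from torch._decomp import decompositions

x = torch.randn(2, 3, 4, 4)
out = decompositions._native_batch_norm_legit_functional(
    x,
    torch.ones(3), torch.zeros(3),  # weight, bias
    torch.zeros(3), torch.ones(3),  # running_mean, running_var
    True, 0.1, 1e-5,                # training, momentum, eps
)
# 5 outputs: out, save_mean, save_rstd, new_running_mean, new_running_var
print(len(out))
```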