-
Notifications
You must be signed in to change notification settings - Fork 63
Add Op(_native_batch_norm_legit_no_training and _native_batch_norm_legit) | feat(torchlib) #1116
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Op(_native_batch_norm_legit_no_training and _native_batch_norm_legit) | feat(torchlib) #1116
Conversation
Codecov Report
@@ Coverage Diff @@
## main #1116 +/- ##
==========================================
+ Coverage 78.44% 78.45% +0.01%
==========================================
Files 118 118
Lines 15018 15021 +3
Branches 1599 1599
==========================================
+ Hits 11781 11785 +4
+ Misses 2870 2867 -3
- Partials 367 369 +2
|
# replace native_batch_norm within unknown time period. | ||
# TODO: Refactor this after native_batch_norm is deprecated. | ||
@torch_op("aten::_native_batch_norm_legit_no_training", trace_only=True) | ||
def aten_native_batch_norm_no_training( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
aten__native_batch_norm_no_training
# NOTE: https://github.com/pytorch/pytorch/blob/a44f8894fa6d973693aab44a3dda079a168b05c1/torch/_decomp/decompositions.py#L1501-L1510 | ||
# _native_batch_norm_legit_no_training and _native_batch_norm_legit are meant to | ||
# replace native_batch_norm within unknown time period. | ||
# TODO: Refactor this after native_batch_norm is deprecated. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Create an issue to track the todo?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@titaiwangms just minor follow ups - thanks! |
From the review comment: #1116 (comment).
Fix #817
Add the support of
_native_batch_norm_legit_no_training
and_native_batch_norm_legit
, which are two new aten ops to replace aten::native_batch_norm according to https://github.com/pytorch/pytorch/blob/a44f8894fa6d973693aab44a3dda079a168b05c1/torch/_decomp/decompositions.py#L1501-L1510.Previous to this PR, due to lack of support of
_native_batch_norm_legit_no_training
and_native_batch_norm_legit
, the exporter decomposesnative_batch_norm
to a bunch of other nodes and drags down the performance.NOTE: The mismatch result size between CUDA/CPU export doesn't happen even with these nodes supported. Could be fixed somewhere else.
Tested with the code: