Add Op(_native_batch_norm_legit_no_training and _native_batch_norm_legit) | feat(torchlib) #1116


Merged

Conversation

titaiwangms
Contributor

Fix #817

Add support for _native_batch_norm_legit_no_training and _native_batch_norm_legit, two new ATen ops that replace aten::native_batch_norm, per https://github.com/pytorch/pytorch/blob/a44f8894fa6d973693aab44a3dda079a168b05c1/torch/_decomp/decompositions.py#L1501-L1510.

Prior to this PR, because _native_batch_norm_legit_no_training and _native_batch_norm_legit were unsupported, the exporter decomposed native_batch_norm into many other nodes, which dragged down performance.

NOTE: The result-size mismatch between CUDA and CPU export is not resolved by supporting these nodes; it could be fixed somewhere else.
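For reference, a minimal sketch of the shape of the new registration: the no-training variant can delegate to the existing aten_native_batch_norm with training hard-coded to False. This is an illustration, not necessarily the exact merged code; the imports assume the torchlib module layout at the time of this PR, and the parameter defaults are assumptions:

from typing import Optional, Tuple

from onnxscript.function_libs.torch_lib.ops.core import aten_native_batch_norm
from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.function_libs.torch_lib.tensor_typing import TFloat


@torch_op("aten::_native_batch_norm_legit_no_training", trace_only=True)
def aten_native_batch_norm_no_training(
    input: TFloat,
    weight: Optional[TFloat] = None,
    bias: Optional[TFloat] = None,
    running_mean: Optional[TFloat] = None,
    running_var: Optional[TFloat] = None,
    momentum: float = 0.9,
    eps: float = 1e-05,
) -> Tuple[TFloat, TFloat, TFloat]:
    # Inference-only batch norm: reuse the existing native_batch_norm
    # implementation with training fixed to False.
    return aten_native_batch_norm(
        input, weight, bias, running_mean, running_var, False, momentum, eps
    )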

Tested with the code:

import onnxruntime
import torch


def repro_split():
    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.bn = torch.nn.BatchNorm2d(64)
            self.conv = torch.nn.Conv2d(64, 64, 3)

        def forward(self, x):
            x = self.bn(x)
            x = self.conv(x)
            return torch.split(x, [16, 24, 24], 1)

    model = Model().cuda().eval()
    x = torch.randn(1, 64, 32, 32).cuda()
    export_output = torch.onnx.dynamo_export(model, x)

    # Sanity-check that the exported proto loads in ONNX Runtime, then save it.
    onnxruntime.InferenceSession(export_output.model_proto.SerializeToString())
    export_output.save("coat_lite_mini.onnx")
    export_output.save_diagnostics("debug_bn.sarif")

    # Run the saved model through ONNX Runtime.
    session = onnxruntime.InferenceSession("coat_lite_mini.onnx")
    input_names = [ort_input.name for ort_input in session.get_inputs()]
    onnx_format_args = export_output.adapt_torch_inputs_to_onnx(x)
    ort_input = {k: v.cpu().numpy() for k, v in zip(input_names, onnx_format_args)}
    print(session.run(None, ort_input))


repro_split()
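To double-check that the batch norm is no longer decomposed, one can inspect the op types in the saved model (this check is my addition, not part of the original repro):

import onnx

model = onnx.load("coat_lite_mini.onnx")
print(sorted({node.op_type for node in model.graph.node}))
# With these ops supported, a BatchNormalization node should appear here
# instead of the long chain of elementwise ops from the decomposition.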

@titaiwangms added the module: torchlib label (Related to the torch/aten function lib in development) on Oct 27, 2023

codecov bot commented Oct 27, 2023

Codecov Report

Merging #1116 (6323b65) into main (70843ef) will increase coverage by 0.01%.
The diff coverage is 75.00%.

@@            Coverage Diff             @@
##             main    #1116      +/-   ##
==========================================
+ Coverage   78.44%   78.45%   +0.01%     
==========================================
  Files         118      118              
  Lines       15018    15021       +3     
  Branches     1599     1599              
==========================================
+ Hits        11781    11785       +4     
+ Misses       2870     2867       -3     
- Partials      367      369       +2     
Files                                           | Coverage Δ
onnxscript/function_libs/torch_lib/ops/core.py  | 79.74% <75.00%> (+0.05%) ⬆️

# _native_batch_norm_legit_no_training and _native_batch_norm_legit are meant to
# replace native_batch_norm within an unknown time period.
# TODO: Refactor this after native_batch_norm is deprecated.
@torch_op("aten::_native_batch_norm_legit_no_training", trace_only=True)
def aten_native_batch_norm_no_training(
Collaborator


aten__native_batch_norm_no_training
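If I read this one-line suggestion correctly, it asks for torchlib's double-underscore convention for ATen ops whose names begin with an underscore. A hypothetical before/after of the function name:

# Name as written in this PR:
def aten_native_batch_norm_no_training(*args): ...

# Suggested name: the double underscore after "aten" mirrors the leading
# underscore in "_native_batch_norm_legit_no_training".
def aten__native_batch_norm_no_training(*args): ...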

# NOTE: https://github.com/pytorch/pytorch/blob/a44f8894fa6d973693aab44a3dda079a168b05c1/torch/_decomp/decompositions.py#L1501-L1510
# _native_batch_norm_legit_no_training and _native_batch_norm_legit are meant to
# replace native_batch_norm within an unknown time period.
# TODO: Refactor this after native_batch_norm is deprecated.
Collaborator


Create an issue to track the todo?

Contributor Author


@justinchuby
Collaborator

@titaiwangms just minor follow-ups - thanks!

Labels
module: torchlib Related to the torch/aten function lib in development
Development

Successfully merging this pull request may close these issues.

[torchlib] aten batch_norm ops have different size results on "CPU" vs "CUDA"
3 participants