[torchlib] aten batch_norm ops have different size results on "CPU" vs "CUDA" #817

Closed
BowenBao opened this issue Jun 29, 2023 · 17 comments · Fixed by #1116
Assignees: titaiwangms
Labels: bug (Something isn't working), module: torchlib (Related to the torch/aten function lib in development)

Comments

@BowenBao
Contributor

BowenBao commented Jun 29, 2023

From the benchmark run:

5 failures like: [ONNXRuntimeError] : 1 : FAIL : Load model from bench_dynamo_onnx_model/coat_lite_mini/model.onnx failed:Node (aten_split_with_sizes_273) output arg (289) type inference failed

        coat_lite_mini
        mixnet_l
        eca_botnext26ts_256
        sebotnet33ts_256
        tf_mixnet_l
        botnet26t_256

Update: minimized repro below. NOTE: it only repros if exported on CUDA.

import torch
import onnxruntime

def repro_split():
    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.bn = torch.nn.BatchNorm2d(64)
            self.conv = torch.nn.Conv2d(64, 64, 3)

        def forward(self, x):
            x = self.bn(x)
            x = self.conv(x)
            return torch.split(x, [16, 24, 24], 1)

    model = Model().cuda().eval()
    x = torch.randn(1, 64, 32, 32).cuda()
    export_output = torch.onnx.dynamo_export(model, x)

    # Fails in ORT with "output arg ... type inference failed" when the model was exported on CUDA.
    onnxruntime.InferenceSession(export_output.model_proto.SerializeToString())

repro_split()

Example model: coat_lite_mini.onnx

import onnxruntime

# onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from coat_lite_mini.onnx failed:Node (aten_split_with_sizes_273) output arg (289) type inference failed
session = onnxruntime.InferenceSession("coat_lite_mini.onnx")
BowenBao added the bug and module: torchlib labels on Jun 29, 2023
@BowenBao
Contributor Author

(screenshot of the split node in the exported graph)

The node looks correct. @gramalingam do you know if there is a current limitation with shape/type inference for functions with sequence outputs?

@justinchuby
Collaborator

Could you share the onnx file?

@BowenBao
Contributor Author

Could you share the onnx file?

Updated in description.

@BowenBao
Contributor Author

Finally got a minified repro; this one was not easy. Updated in the description.

@BowenBao
Contributor Author

It might be more appropriate to file this issue under the pytorch repo. But since we might need to update both repos anyway, I will keep it here.

@justinchuby
Collaborator

So the graph exported with CUDA is different?

@BowenBao
Contributor Author

BowenBao commented Jul 18, 2023

Intriguing. The graph is almost identical, except for one mismatch in a dimension size inside value_info_proto. It turns out the root cause is that native_batch_norm has different implementations on CPU vs CUDA, as discussed in pytorch/pytorch#100985. E.g., a CUDA export marks the mean output shape as {num_features}, whereas a CPU export (and the torchlib impl) produces {0}.
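A minimal way to observe this outside the exporter (a sketch using the aten op directly; running the same call on `.cuda()` tensors reports `[64]` for the saved stats instead):

```python
import torch

bn = torch.nn.BatchNorm2d(64).eval()
x = torch.randn(1, 64, 32, 32)
# aten::native_batch_norm returns (output, save_mean, save_invstd).
_, save_mean, save_invstd = torch.ops.aten.native_batch_norm(
    x, bn.weight, bn.bias, bn.running_mean, bn.running_var,
    False,  # training
    0.1,    # momentum
    1e-5,   # eps
)
# CPU eval mode (and torchlib) report shape [0] for the saved stats;
# the CUDA kernel reports [num_features] instead (pytorch/pytorch#100985).
print(save_mean.shape, save_invstd.shape)
```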

I made an unsuccessful attempt at pytorch/pytorch#105352 to align the implementations. However, it seems that might break backward compatibility for too many downstream packages, so for a leaner solution we are back to handling it in the exporter.

@justinchuby Can we override torch's shape/type with ONNX shape/type inference in the final model proto?

@justinchuby
Collaborator

Do we need to / is it possible to create a CUDA-specific function for it? We can go in and edit the model, but ONNX shape inference tends to create more issues than it solves, so I wouldn't rely on it.
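For reference, the "override with ONNX shape inference" option would amount to something like this sketch (file name taken from the snippet in the description; note that ONNX shape inference also has limited support for model-local functions, which is part of the concern here):

```python
import onnx

model = onnx.load("coat_lite_mini.onnx")
# strict_mode surfaces inference errors instead of silently skipping nodes.
inferred = onnx.shape_inference.infer_shapes(model, strict_mode=True)
onnx.save(inferred, "coat_lite_mini_inferred.onnx")
```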

@BowenBao
Contributor Author

is it possible to create a CUDA-specific function for it?

Op name and signature are device agnostic. This approach would require the dispatcher to be able to dispatch based on the device types of the tensor args.

If we can't think of anything more systematic, the last resort is an FX pass that re-runs batch norm under fake tensors on the CPU device and updates the meta info.
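A hypothetical sketch of that last-resort FX pass (not the exporter's actual code; the op targets, the `meta["val"]` layout, and the `FakeTensor.fake_mode` attribute are assumptions about current PyTorch internals):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor

# Batch norm variants whose saved-stat outputs differ between CPU and CUDA.
_BN_TARGETS = {
    torch.ops.aten.native_batch_norm.default,
    torch.ops.aten._native_batch_norm_legit_no_training.default,
}

def align_batch_norm_stats_meta(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Rewrite the saved-stat metadata of batch norm nodes to the CPU convention
    (shape [0]) so downstream shape propagation matches the torchlib impl."""
    for node in gm.graph.nodes:
        if node.op != "call_function" or node.target not in _BN_TARGETS:
            continue
        val = node.meta.get("val")
        if not (isinstance(val, (tuple, list)) and len(val) == 3):
            continue
        out, save_mean, save_invstd = val
        if isinstance(save_mean, FakeTensor) and save_mean.shape != (0,):
            # Re-create the stats with the empty shape the CPU kernel produces.
            with save_mean.fake_mode:
                empty_mean = torch.empty(0, dtype=save_mean.dtype)
                empty_invstd = torch.empty(0, dtype=save_invstd.dtype)
            node.meta["val"] = (out, empty_mean, empty_invstd)
    return gm
```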

@justinchuby
Collaborator

When the CUDA shape is different, do the ops consuming that output handle the different shape? I guess my question is: why are both shapes valid?

@BowenBao
Contributor Author

If not training, these outputs from the low-level aten::native_batch_norm are unused. If training, they are used, and their shapes are the same.

After spending some effort digging in the pytorch/aten space, I discovered a long list of batch norm ops. The related ones are

native_batch_norm
_native_batch_norm_legit
_native_batch_norm_legit.no_stats
_native_batch_norm_legit_no_training

There was a motion to promote the legit ones as the public batch norm op in the core IR, however there haven't been updates for a while. Regardless, the legit ones also have this CPU/CUDA mismatch behavior, probably for backward-compatibility reasons, so we are not better off with them.

Then why does this matter to us? Because, as a side effect of pytorch/pytorch#105764, the exporter will start to dispatch to _native_batch_norm_legit_no_training, which is the correct behavior. And since we don't have a torchlib op written for it, _native_batch_norm_legit_no_training gets decomposed. The good news is that we don't have this shape issue anymore.
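A quick way to check which of these variants the default decomposition table covers, and therefore which ones get decomposed when no torchlib function is registered (a sketch, assuming the `torch._decomp` API of recent PyTorch):

```python
import torch
from torch._decomp import core_aten_decompositions

decomps = core_aten_decompositions()
for op in (
    torch.ops.aten.native_batch_norm.default,
    torch.ops.aten._native_batch_norm_legit.default,
    torch.ops.aten._native_batch_norm_legit_no_training.default,
):
    print(op, "has a default decomposition:", op in decomps)
```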

@BowenBao
Contributor Author

Renaming the issue and moving to low priority. This issue is unblocked as explained.

However, it will resurface if one decides to add torchlib implementations for all aten batch norm ops and overloads.

BowenBao changed the title from "ONNXRuntimeError: Node (aten_split_with_sizes_273) output arg (289) type inference failed" to "[torchlib] aten batch_norm ops have different size results on 'CPU' vs 'CUDA'" on Jul 31, 2023
@BowenBao
Contributor Author

This came back to bite us on the performance side, lol.

@justinchuby since the bn op is decomposed, we don't get onnx::BatchNormalization in our final model.
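A quick check on the saved model to confirm whether any onnx::BatchNormalization node survives (a sketch; file name taken from the repro, and model-local functions are scanned too since the dynamo exporter emits aten ops as ONNX functions):

```python
import onnx
from collections import Counter

model = onnx.load("coat_lite_mini.onnx")
counts = Counter(node.op_type for node in model.graph.node)
for function in model.functions:
    counts.update(node.op_type for node in function.node)
print("BatchNormalization nodes:", counts.get("BatchNormalization", 0))
```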

@justinchuby
Collaborator

Would implementing _native_batch_norm_legit_no_training help? Or other ops?

@BowenBao
Contributor Author

I think implementing it (whichever aten op it is) helps; then the issue in the original post may come back (for CUDA), and we might need to hack around it.

@justinchuby
Collaborator

Are the new aten ops still different in shape with different devices?

titaiwangms self-assigned this on Oct 26, 2023
@BowenBao
Contributor Author

Not sure if it is fixed, but it certainly wasn't high priority on pytorch's list.

titaiwangms added a commit that referenced this issue Oct 28, 2023
…git) | feat(torchlib) (#1116)

Fix #817

Add support for `_native_batch_norm_legit_no_training` and
`_native_batch_norm_legit`, which are two new aten ops that replace
aten::native_batch_norm, according to
https://github.com/pytorch/pytorch/blob/a44f8894fa6d973693aab44a3dda079a168b05c1/torch/_decomp/decompositions.py#L1501-L1510.

Prior to this PR, due to the lack of support for
`_native_batch_norm_legit_no_training` and `_native_batch_norm_legit`,
the exporter decomposed `native_batch_norm` into a number of other nodes,
which dragged down performance.

NOTE: The result-size mismatch between CUDA and CPU export does not occur
even with these nodes supported. It may have been fixed somewhere else.

Tested with the code:

```python
import torch

import onnxruntime


def repro_split():
    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.bn = torch.nn.BatchNorm2d(64)
            self.conv = torch.nn.Conv2d(64, 64, 3)

        def forward(self, x):
            x = self.bn(x)
            x = self.conv(x)
            return torch.split(x, [16, 24, 24], 1)

    model = Model().cuda().eval()
    x = torch.randn(1, 64, 32, 32).cuda()
    export_output = torch.onnx.dynamo_export(model, x)

    onnxruntime.InferenceSession(export_output.model_proto.SerializeToString())
    export_output.save("coat_lite_mini.onnx")
    export_output.save_diagnostics("debug_bn.sarif")

    session = onnxruntime.InferenceSession("coat_lite_mini.onnx")
    input_names = [ort_input.name for ort_input in session.get_inputs()]
    onnx_format_args = export_output.adapt_torch_inputs_to_onnx(x)
    ort_input = {k: v.cpu().numpy() for k, v in zip(input_names, onnx_format_args)}
    print(session.run(None, ort_input))


repro_split()
```