[torchlib] aten batch_norm ops have different size results on "CPU" vs "CUDA" #817
Comments
The node looks correct. @gramalingam do you know if there is a current limitation in the shape/type inference function for sequence outputs?
Could you share the onnx file?
Updated in description.
Finally got a minified repro, this one is not easy. Updated in description.
It might be more appropriate to file this issue under the pytorch repo, but since we might need to update both repos anyway, I will keep it here.
So the graph exported on CUDA is different?
Intriguing. The graph is almost identical, except for one mismatch in dimension size inside value_info_proto. It turns out the root cause is that batch norm has different implementations on CPU and CUDA. I have an unsuccessful attempt at aligning the implementations in pytorch/pytorch#105352; however, it seems it might BC-break too many downstream packages, so for a leaner solution we are back to handling it in the exporter. @justinchuby can we override torch's shape/type with ONNX shape/type inference in the final model proto?
Do we need to, or is it possible to, create a CUDA-specific function for it? We can go in and edit the model, but ONNX shape inference tends to create more issues than it solves, so I wouldn't rely on it.
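For reference, a minimal sketch of the "override the recorded shapes with ONNX shape inference" idea discussed above, assuming an in-memory `ModelProto` and that the exporter's `value_info` can simply be dropped; illustrative only, not an existing exporter feature:

```python
import onnx
from onnx import shape_inference


def override_shapes_with_onnx_inference(model: onnx.ModelProto) -> onnx.ModelProto:
    # Drop the value_info entries recorded by the exporter (propagated from
    # torch's fake tensors, which is where the CPU/CUDA mismatch lives) ...
    del model.graph.value_info[:]
    # ... and let ONNX shape inference repopulate them from the graph itself.
    return shape_inference.infer_shapes(model, strict_mode=False)
```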
The op name and signature are device agnostic, so this approach would require the dispatcher to dispatch based on the device types of the tensor args. If we can't think of anything more systematic, the last resort is an fx pass to re-fake-run batch norm on the CPU device and update the meta info.
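A rough sketch of that last-resort fx pass; the function name `refake_batch_norm_on_cpu` and the way CPU fake inputs are rebuilt from `node.meta["val"]` are assumptions for illustration, not an existing pass:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.node import map_arg


def refake_batch_norm_on_cpu(gm: torch.fx.GraphModule) -> None:
    fake_mode = FakeTensorMode()

    def to_cpu_fake(n: torch.fx.Node):
        # Rebuild a CPU fake tensor with the same shape/dtype as the recorded meta value.
        val = n.meta["val"]
        return fake_mode.from_tensor(torch.empty(val.shape, dtype=val.dtype, device="cpu"))

    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target is torch.ops.aten.native_batch_norm.default:
            args = map_arg(node.args, to_cpu_fake)
            kwargs = map_arg(node.kwargs, to_cpu_fake)
            # Re-run the op with CPU fake inputs and overwrite the recorded meta,
            # so downstream shape propagation sees the CPU output sizes.
            with fake_mode:
                node.meta["val"] = node.target(*args, **kwargs)
```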
When the CUDA shape is different, do the ops using that output handle the different shape? I guess my question is: why are both shapes valid?
If not training, these outputs from this low-level op are not used, so either shape works downstream. After spending some effort digging in the pytorch/aten space, I discovered a long list of batch norm ops. The related ones are `native_batch_norm`, `_native_batch_norm_legit`, and `_native_batch_norm_legit_no_training`.
There was motion to promote the `_native_batch_norm_legit*` ops as the replacement for `aten::native_batch_norm`. Then why does this matter to us? Because, as a side effect of pytorch/pytorch#105764, the exporter will start to dispatch to `_native_batch_norm_legit_no_training`, which torchlib does not implement, so the op falls back to its decomposition.
Renaming the issue and moving it to low priority. This issue is unblocked as explained above; however, it will resurface if one decides to add torchlib implementations for all aten batch norm ops and overloads.
This came back to bite us on the performance side, lol. @justinchuby since the bn op is decomposed, we don't get a single fused BatchNormalization node in the exported graph, which drags down performance.
Would implementing `_native_batch_norm_legit_no_training` help? Or other ops?
I think it helps (implementing whichever aten op it is), but then the issue in the original post may come back (for CUDA), and we might need to hack around it.
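For context, a minimal onnxscript-flavored sketch of what such a torchlib function might look like; this is not the implementation that later landed in #1116, registration details are omitted, and the empty placeholder outputs for save_mean/save_rstd are an assumption:

```python
from onnxscript import opset18 as op


def aten__native_batch_norm_legit_no_training(
    input, weight, bias, running_mean, running_var, momentum: float, eps: float
):
    # Emit a single fused BatchNormalization node instead of the decomposed subgraph.
    norm = op.BatchNormalization(
        input, weight, bias, running_mean, running_var,
        epsilon=eps, training_mode=0,
    )
    # The aten op also returns save_mean/save_rstd; empty tensors stand in for them
    # here (an assumption; the real torchlib function defines these precisely).
    empty = op.ConstantOfShape(op.Constant(value_ints=[0]))
    return norm, empty, empty
```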
Do the new aten ops still produce different shapes on different devices?
Not sure if it is fixed, but it certainly wasn't high priority on pytorch's list.
…git) | feat(torchlib) (#1116)

Fix #817

Add support for `_native_batch_norm_legit_no_training` and `_native_batch_norm_legit`, two new aten ops that replace `aten::native_batch_norm` according to https://github.com/pytorch/pytorch/blob/a44f8894fa6d973693aab44a3dda079a168b05c1/torch/_decomp/decompositions.py#L1501-L1510. Prior to this PR, because `_native_batch_norm_legit_no_training` and `_native_batch_norm_legit` were unsupported, the exporter decomposed `native_batch_norm` into a bunch of other nodes, which dragged down performance.

NOTE: The result-size mismatch between CUDA and CPU export doesn't happen even with these nodes supported. It may have been fixed somewhere else.

Tested with the code:

```python
import torch
import onnxruntime


def repro_split():
    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.bn = torch.nn.BatchNorm2d(64)
            self.conv = torch.nn.Conv2d(64, 64, 3)

        def forward(self, x):
            x = self.bn(x)
            x = self.conv(x)
            return torch.split(x, [16, 24, 24], 1)

    model = Model().cuda().eval()
    x = torch.randn(1, 64, 32, 32).cuda()

    export_output = torch.onnx.dynamo_export(model, x)
    onnxruntime.InferenceSession(export_output.model_proto.SerializeToString())
    export_output.save("coat_lite_mini.onnx")
    export_output.save_diagnostics("debug_bn.sarif")

    session = onnxruntime.InferenceSession("coat_lite_mini.onnx")
    input_names = [ort_input.name for ort_input in session.get_inputs()]
    onnx_format_args = export_output.adapt_torch_inputs_to_onnx(x)
    ort_input = {k: v.cpu().numpy() for k, v in zip(input_names, onnx_format_args)}
    print(session.run(None, ort_input))


repro_split()
```
From the benchmark, 5 failures like:

[ONNXRuntimeError] : 1 : FAIL : Load model from bench_dynamo_onnx_model/coat_lite_mini/model.onnx failed:Node (aten_split_with_sizes_273) output arg (289) type inference failed
Update: minimized repro. NOTE: it only repros when exported on CUDA.
Example coat_lite_mini.onnx