Missing aten._native_batch_norm_legit_no_training.default #1118

@thiagocrepaldi

Description

The missing operator is part of the torch.compiler IR that is supported after pytorch/pytorch#111497.

The failing test is in test/onnx/test_fx_to_onnx_with_onnxruntime.py, run with the output of torch.export.export(torchvision.models.resnet18(weights=None).eval()) as input to torch.onnx.dynamo_export:

    @skip_if_no_torchvision
    def test_resnet18(self):
        # TODO(bowbao): Note [training vs eval in dynamo_export]
        # So we are effectively exporting all models in training mode by
        # default. But for the sake of this export we are only interested in eval mode.
        # The question is, should we call `model.eval()` in `dynamo_export`?
        # This particular test fails 'functionalization' in training mode.
        # So we are explicitly calling `model.eval()` for any model that contains
        # batch norm.
        # Ref: https://github.com/pytorch/pytorch/issues/99662#issuecomment-1528178221
        model = torchvision.models.resnet18(weights=None).eval()
        dummy_input = torch.randn(1, 3, 224, 224)

        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
            model, (dummy_input,), model_type=self.model_type
        )
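
Outside the test harness, a minimal standalone repro would look roughly like the following (a sketch, assuming torchvision is installed and that the ExportedProgram is passed directly to dynamo_export, as in the test above):

    import torch
    import torchvision

    model = torchvision.models.resnet18(weights=None).eval()
    dummy_input = torch.randn(1, 3, 224, 224)

    # Export to an ExportedProgram first, then hand it to the dynamo-based ONNX exporter.
    exported_program = torch.export.export(model, (dummy_input,))

    # Fails with: Unsupported FX nodes:
    # {'call_function': ['aten._native_batch_norm_legit_no_training.default']}
    onnx_program = torch.onnx.dynamo_export(exported_program, dummy_input)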

backtrace

Traceback (most recent call last):
  File "/opt/pytorch/torch/onnx/_internal/exporter.py", line 1228, in dynamo_export
    return Exporter(
  File "/opt/pytorch/torch/onnx/_internal/exporter.py", line 979, in export
    graph_module = self.options.fx_tracer.generate_fx(
  File "/opt/pytorch/torch/onnx/_internal/fx/torch_export_graph_extractor.py", line 60, in generate_fx
    return self.pre_export_passes(options, model, model.graph_module, updated_model_args)  # type: ignore[return-value]
  File "<@beartype(torch.onnx._internal.fx.torch_export_graph_extractor.TorchExport.pre_export_passes) at 0x7f22f35c9f70>", line 93, in pre_export_passes
  File "/opt/pytorch/torch/onnx/_internal/fx/torch_export_graph_extractor.py", line 79, in pre_export_passes
    analysis.UnsupportedFxNodesAnalysis(
  File "/opt/pytorch/torch/onnx/_internal/fx/analysis/unsupported_nodes.py", line 74, in analyze
    self._lint(analysis_result, diagnostic_level)
  File "/opt/pytorch/torch/onnx/_internal/fx/analysis/unsupported_nodes.py", line 38, in _lint
    self.diagnostic_context.log_and_raise_if_error(diagnostic)
  File "/opt/pytorch/torch/onnx/_internal/diagnostics/infra/context.py", line 367, in log_and_raise_if_error
    raise RuntimeErrorWithDiagnostic(diagnostic)
torch.onnx._internal.diagnostics.infra.context.RuntimeErrorWithDiagnostic: Unsupported FX nodes: {'call_function': ['aten._native_batch_norm_legit_no_training.default']}. 

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/opt/conda/envs/ptca/lib/python3.8/unittest/case.py", line 60, in testPartExecutor
    yield
  File "/opt/conda/envs/ptca/lib/python3.8/unittest/case.py", line 676, in run
    self._callTestMethod(testMethod)
  File "/opt/conda/envs/ptca/lib/python3.8/unittest/case.py", line 633, in _callTestMethod
    method()
  File "/opt/pytorch/torch/testing/_internal/common_utils.py", line 2528, in wrapper
    method(*args, **kwargs)
  File "/opt/pytorch/test/onnx/test_fx_to_onnx_with_onnxruntime.py", line 280, in test_resnet18
    self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
  File "<@beartype(onnx_test_common._TestONNXRuntime.run_test_with_fx_to_onnx_exporter_and_onnx_runtime) at 0x7f21a1f4d670>", line 260, in run_test_with_fx_to_onnx_exporter_and_onnx_runtime
  File "/opt/pytorch/test/onnx/onnx_test_common.py", line 279, in run_test_with_fx_to_onnx_exporter_and_onnx_runtime
    self.dynamic_shapes
  File "/opt/pytorch/test/onnx/onnx_test_common.py", line 279, in run_test_with_fx_to_onnx_exporter_and_onnx_runtime
    self.dynamic_shapes
  File "<@beartype(torch.onnx._internal.exporter.dynamo_export) at 0x7f2304b761f0>", line 51, in dynamo_export
  File "/opt/pytorch/torch/onnx/_internal/exporter.py", line 1244, in dynamo_export
    raise OnnxExporterError(
torch.onnx.OnnxExporterError: Failed to export the model to ONNX. Generating SARIF report at 'report_dynamo_export.sarif'. SARIF is a standard format for the output of static analysis tools. SARIF logs can be loaded in VS Code SARIF viewer extension, or SARIF web viewer (https://microsoft.github.io/sarif-web-component/). Please report a bug on PyTorch Github: https://github.com/pytorch/pytorch/issues

To execute this test, run the following from the base repo dir:
     python test/onnx/test_fx_to_onnx_with_onnxruntime.py -k test_resnet18

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
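
For reference, the unsupported node can be seen directly in the exported FX graph before the ONNX exporter is involved; a minimal sketch that lists the call_function targets (same model and input as above, assuming an eval-mode export):

    import torch
    import torchvision

    model = torchvision.models.resnet18(weights=None).eval()
    exported_program = torch.export.export(model, (torch.randn(1, 3, 224, 224),))

    # In eval mode, batch norm decomposes to
    # aten._native_batch_norm_legit_no_training.default, which torchlib does not yet cover.
    call_targets = {
        str(node.target)
        for node in exported_program.graph.nodes
        if node.op == "call_function"
    }
    for target in sorted(call_targets):
        print(target)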

Metadata

    Labels

    module: torchlib (Related to the torch/aten function lib in development)
