Skip to content

Optimizer fails on multi-output control flow #2184

Closed
@justinchuby

Description

@justinchuby

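This appears to be a bug in constant lifting: `LiftConstantsToInitializersPass` tries to remove a `Constant` node whose output is also a graph output, and `graph.remove(node, safe=True)` rejects that.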

import torch

class CondModel(torch.nn.Module):
    """Minimal reproduction: a ``torch.cond`` whose branches each return two tensors.

    Exporting this model with ``optimize=True`` fails inside
    ``LiftConstantsToInitializersPass`` (see the traceback in this issue).
    """

    def forward(self, x):
        """Run a two-output ``torch.cond`` and return its result together with ``z``.

        The returned ``z`` is the original ``torch.ones_like(x)`` tensor, not
        the ``z`` computed by either branch. The first return value is
        whatever ``torch.cond`` returns for the two branch outputs.
        """
        z = torch.ones_like(x)

        def true_fn(x, z):
            x = x + 1.0
            z = z * 1.0
            return x, z

        def false_fn(x, z):
            x = x - 1.0
            z = z * 0.0
            return x, z

        # Both branches return a pair, so ``x`` is rebound to the (presumably
        # tuple/list) result of ``torch.cond``. It no longer refers to the input tensor.
        x = torch.cond(x.sum() > 0, true_fn, false_fn, (x, z))
        # ``z`` is returned directly as a model output. Per the traceback, its
        # producer (exported as the Constant node 'ones_like') is therefore a
        # graph output, which the constant-lifting pass then fails to remove.
        return x, z

# Export with the dynamo-based exporter and run the optimizer (optimize=True).
# The input is an int64 tensor of shape [2]. This call raises the PassError
# shown in the traceback below.
onnx_program = torch.onnx.export(CondModel(), (torch.tensor([1, 2]),), optimize=True, dynamo=True)
Traceback (most recent call last):
  File "/home/justinchu/dev/onnxscript/onnxscript/ir/passes/_pass_infra.py", line 211, in call
    pass_result = pass_(model)
                  ^^^^^^^^^^^^
  File "/home/justinchu/dev/onnxscript/onnxscript/ir/passes/_pass_infra.py", line 122, in __call__
    result = self.call(model)
             ^^^^^^^^^^^^^^^^
  File "/home/justinchu/dev/onnxscript/onnxscript/ir/passes/common/constant_manipulation.py", line 59, in call
    node.graph.remove(node, safe=True)
  File "/home/justinchu/dev/onnxscript/onnxscript/ir/_core.py", line 2124, in remove
    _check_node_safe_to_remove(node, nodes_set, graph_outputs)
  File "/home/justinchu/dev/onnxscript/onnxscript/ir/_core.py", line 1839, in _check_node_safe_to_remove
    raise ValueError(
ValueError: Node 'Node(name='node_Constant_11', domain='', op_type='Constant', inputs=(), attributes=OrderedDict({'value': Attr('value', TENSOR, Tensor<INT64,[2]>(array([1, 1]), name='ones_like'))}), overload='', outputs=(Value('ones_like', type=Tensor(INT64), shape=[2], producer=node_Constant_11, index=0),), version=None, doc_string=None)' is still an output of the graph and cannot be removed when safe=True.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/justinchu/dev/pytorch/torch/testing/_internal/common_utils.py", line 3156, in wrapper
    method(*args, **kwargs)
  File "/home/justinchu/dev/pytorch/test/onnx/exporter/test_small_models_e2e.py", line 153, in test_onnx_export_control_flow_multi_outputs
    onnx_program = self.export(CondModel(), (torch.tensor([1, 2]),))
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/justinchu/dev/pytorch/test/onnx/exporter/test_small_models_e2e.py", line 20, in export
    onnx_program = torch.onnx.export(
                   ^^^^^^^^^^^^^^^^^^
  File "/home/justinchu/dev/pytorch/torch/onnx/__init__.py", line 367, in export
    return _compat.export_compat(
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/justinchu/dev/pytorch/torch/onnx/_internal/exporter/_compat.py", line 183, in export_compat
    onnx_program.optimize()
  File "/home/justinchu/dev/pytorch/torch/onnx/_internal/exporter/_onnx_program.py", line 208, in optimize
    self.model = onnxscript_apis.optimize(self.model)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/justinchu/dev/onnxscript/onnxscript/_framework_apis/torch_2_6.py", line 30, in optimize
    optimizer.optimize_ir(model)
  File "/home/justinchu/dev/onnxscript/onnxscript/optimizer/_optimizer.py", line 59, in optimize_ir
    result = optimizer_pass(model)
             ^^^^^^^^^^^^^^^^^^^^^
  File "/home/justinchu/dev/onnxscript/onnxscript/ir/passes/_pass_infra.py", line 122, in __call__
    result = self.call(model)
             ^^^^^^^^^^^^^^^^
  File "/home/justinchu/dev/onnxscript/onnxscript/ir/passes/_pass_infra.py", line 214, in call
    raise PassError(
onnxscript.ir.passes.PassError: An error occurred when running the '<onnxscript.ir.passes.common.constant_manipulation.LiftConstantsToInitializersPass object at 0x7f40be33e750>' pass after the following passes: ['<onnxscript.optimizer._inliner.InlinePass object at 0x7f40be6746b0>', '<onnxscript.ir.passes.PassManager object at 0x7f40be33d1c0>', '<onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass object at 0x7f40be33d190>']

cc @titaiwangms

Metadata

Metadata

Assignees

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions