Closed
Description
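Seems to be an issue with constant lifting: `LiftConstantsToInitializersPass` tries to remove a `Constant` node whose output is still a graph output, so the safe removal raises a `ValueError`. Minimal reproduction, followed by the full traceback: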
import torch
class CondModel(torch.nn.Module):
    """Minimal reproduction model: ``torch.cond`` with two-output branches.

    ``forward`` builds ``z`` with ``torch.ones_like(x)``, passes ``x`` and
    ``z`` as operands to ``torch.cond``, and returns the result of the
    conditional together with the original ``z``.
    """

    def forward(self, x):
        """Run the conditional on ``x`` and return ``(cond_result, z)``.

        Args:
            x: Input tensor; the branch is selected by ``x.sum() > 0``.

        Returns:
            A pair whose first item is whatever ``torch.cond`` returned (both
            branches return an ``(x, z)`` pair) and whose second item is the
            ``ones_like`` tensor created at the top of this method.
        """
        z = torch.ones_like(x)

        def true_fn(x, z):
            # Taken when x.sum() > 0: shift x up by one, keep z unchanged.
            x = x + 1.0
            z = z * 1.0
            return x, z

        def false_fn(x, z):
            # Taken otherwise: shift x down by one, zero out z.
            x = x - 1.0
            z = z * 0.0
            return x, z

        # x is rebound to the result of torch.cond; z is NOT rebound, so the
        # z returned below is the ones_like tensor from above.
        # NOTE(review): the traceback in this report names a Constant
        # 'ones_like' that is still a graph output -- presumably this
        # returned z; confirm against the exported graph.
        x = torch.cond(x.sum() > 0, true_fn, false_fn, (x, z))
        return x, z
# Export with the dynamo-based exporter and run the optimizer; per the
# traceback below, this call fails inside onnx_program.optimize().
onnx_program = torch.onnx.export(
    CondModel(),
    (torch.tensor([1, 2]),),
    optimize=True,
    dynamo=True,
)
Traceback (most recent call last):
File "/home/justinchu/dev/onnxscript/onnxscript/ir/passes/_pass_infra.py", line 211, in call
pass_result = pass_(model)
^^^^^^^^^^^^
File "/home/justinchu/dev/onnxscript/onnxscript/ir/passes/_pass_infra.py", line 122, in __call__
result = self.call(model)
^^^^^^^^^^^^^^^^
File "/home/justinchu/dev/onnxscript/onnxscript/ir/passes/common/constant_manipulation.py", line 59, in call
node.graph.remove(node, safe=True)
File "/home/justinchu/dev/onnxscript/onnxscript/ir/_core.py", line 2124, in remove
_check_node_safe_to_remove(node, nodes_set, graph_outputs)
File "/home/justinchu/dev/onnxscript/onnxscript/ir/_core.py", line 1839, in _check_node_safe_to_remove
raise ValueError(
ValueError: Node 'Node(name='node_Constant_11', domain='', op_type='Constant', inputs=(), attributes=OrderedDict({'value': Attr('value', TENSOR, Tensor<INT64,[2]>(array([1, 1]), name='ones_like'))}), overload='', outputs=(Value('ones_like', type=Tensor(INT64), shape=[2], producer=node_Constant_11, index=0),), version=None, doc_string=None)' is still an output of the graph and cannot be removed when safe=True.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/justinchu/dev/pytorch/torch/testing/_internal/common_utils.py", line 3156, in wrapper
method(*args, **kwargs)
File "/home/justinchu/dev/pytorch/test/onnx/exporter/test_small_models_e2e.py", line 153, in test_onnx_export_control_flow_multi_outputs
onnx_program = self.export(CondModel(), (torch.tensor([1, 2]),))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/justinchu/dev/pytorch/test/onnx/exporter/test_small_models_e2e.py", line 20, in export
onnx_program = torch.onnx.export(
^^^^^^^^^^^^^^^^^^
File "/home/justinchu/dev/pytorch/torch/onnx/__init__.py", line 367, in export
return _compat.export_compat(
^^^^^^^^^^^^^^^^^^^^^^
File "/home/justinchu/dev/pytorch/torch/onnx/_internal/exporter/_compat.py", line 183, in export_compat
onnx_program.optimize()
File "/home/justinchu/dev/pytorch/torch/onnx/_internal/exporter/_onnx_program.py", line 208, in optimize
self.model = onnxscript_apis.optimize(self.model)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/justinchu/dev/onnxscript/onnxscript/_framework_apis/torch_2_6.py", line 30, in optimize
optimizer.optimize_ir(model)
File "/home/justinchu/dev/onnxscript/onnxscript/optimizer/_optimizer.py", line 59, in optimize_ir
result = optimizer_pass(model)
^^^^^^^^^^^^^^^^^^^^^
File "/home/justinchu/dev/onnxscript/onnxscript/ir/passes/_pass_infra.py", line 122, in __call__
result = self.call(model)
^^^^^^^^^^^^^^^^
File "/home/justinchu/dev/onnxscript/onnxscript/ir/passes/_pass_infra.py", line 214, in call
raise PassError(
onnxscript.ir.passes.PassError: An error occurred when running the '<onnxscript.ir.passes.common.constant_manipulation.LiftConstantsToInitializersPass object at 0x7f40be33e750>' pass after the following passes: ['<onnxscript.optimizer._inliner.InlinePass object at 0x7f40be6746b0>', '<onnxscript.ir.passes.PassManager object at 0x7f40be33d1c0>', '<onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass object at 0x7f40be33d190>']
cc @titaiwangms