Skip to content

Commit 3af94a7

Browse files
authored
[IR] Handle invalid output deserialization (#2223)
Handle deserializing a graph if an output that is not produced by any nodes. This is discovered when working on microsoft/onnxruntime-genai#1416
1 parent 6d33d22 commit 3af94a7

File tree

2 files changed

+34
-2
lines changed

2 files changed

+34
-2
lines changed

onnxscript/ir/serde.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -672,8 +672,23 @@ def _deserialize_graph(
672672
for node in proto.node
673673
]
674674

675-
# Fill in values for graph outputs
676-
outputs = [deserialize_value_info_proto(info, values[info.name]) for info in proto.output]
675+
outputs = []
676+
for info in proto.output:
677+
# Fill in values for graph outputs
678+
output_name = info.name
679+
if output_name not in values:
680+
# Handle (invalid) graph outputs that do not have any producers
681+
logger.warning(
682+
"Output '%s' is not produced by any node. The graph has an invalid output",
683+
output_name,
684+
)
685+
value = _core.Value(name=output_name)
686+
else:
687+
# A valid, normal graph output
688+
value = values[output_name]
689+
# Fill in shape/type information
690+
deserialize_value_info_proto(info, value)
691+
outputs.append(value)
677692

678693
# Exit the graph scope by popping the values for this scope from the stack
679694
scoped_values.pop()

onnxscript/ir/serde_test.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,23 @@ def test_deserialize_graph_handles_unsorted_graph(self):
290290
self.assertEqual(deserialized_graph[0].op_type, "Op_1")
291291
self.assertEqual(deserialized_graph[1].op_type, "Op_0")
292292

293+
def test_deserialize_graph_handles_invalid_output(self):
294+
# The graph has an output that is not connected to any node, and it does not
295+
# have shape/type information.
296+
graph_with_invalid_output = ir.Graph(
297+
inputs=[],
298+
outputs=[ir.Value(name="invalid_output")],
299+
nodes=[],
300+
name="graph_with_invalid_output",
301+
)
302+
graph_proto = serde.serialize_graph(graph_with_invalid_output)
303+
deserialized_graph = serde.deserialize_graph(graph_proto)
304+
self.assertEqual(len(deserialized_graph.outputs), 1)
305+
self.assertEqual(deserialized_graph.outputs[0].name, "invalid_output")
306+
self.assertEqual(deserialized_graph.outputs[0].type, None)
307+
self.assertEqual(deserialized_graph.outputs[0].shape, None)
308+
self.assertEqual(deserialized_graph.outputs[0].dtype, None)
309+
293310

294311
class QuantizationAnnotationTest(unittest.TestCase):
295312
"""Test that quantization annotations are correctly serialized and deserialized."""

0 commit comments

Comments
 (0)