diff --git a/onnxscript/ir/serde.py b/onnxscript/ir/serde.py index bf39c1ea31..64703b2baa 100644 --- a/onnxscript/ir/serde.py +++ b/onnxscript/ir/serde.py @@ -672,8 +672,23 @@ def _deserialize_graph( for node in proto.node ] - # Fill in values for graph outputs - outputs = [deserialize_value_info_proto(info, values[info.name]) for info in proto.output] + outputs = [] + for info in proto.output: + # Fill in values for graph outputs + output_name = info.name + if output_name not in values: + # Handle (invalid) graph outputs that do not have any producers + logger.warning( + "Output '%s' is not produced by any node. The graph has an invalid output", + output_name, + ) + value = _core.Value(name=output_name) + else: + # A valid, normal graph output + value = values[output_name] + # Fill in shape/type information + deserialize_value_info_proto(info, value) + outputs.append(value) # Exit the graph scope by popping the values for this scope from the stack scoped_values.pop() diff --git a/onnxscript/ir/serde_test.py b/onnxscript/ir/serde_test.py index 416020afeb..303f02761f 100644 --- a/onnxscript/ir/serde_test.py +++ b/onnxscript/ir/serde_test.py @@ -290,6 +290,23 @@ def test_deserialize_graph_handles_unsorted_graph(self): self.assertEqual(deserialized_graph[0].op_type, "Op_1") self.assertEqual(deserialized_graph[1].op_type, "Op_0") + def test_deserialize_graph_handles_invalid_output(self): + # The graph has an output that is not connected to any node, and it does not + # have shape/type information. + graph_with_invalid_output = ir.Graph( + inputs=[], + outputs=[ir.Value(name="invalid_output")], + nodes=[], + name="graph_with_invalid_output", + ) + graph_proto = serde.serialize_graph(graph_with_invalid_output) + deserialized_graph = serde.deserialize_graph(graph_proto) + self.assertEqual(len(deserialized_graph.outputs), 1) + self.assertEqual(deserialized_graph.outputs[0].name, "invalid_output") + self.assertEqual(deserialized_graph.outputs[0].type, None) + self.assertEqual(deserialized_graph.outputs[0].shape, None) + self.assertEqual(deserialized_graph.outputs[0].dtype, None) + class QuantizationAnnotationTest(unittest.TestCase): """Test that quantization annotations are correctly serialized and deserialized."""