From b10a530d20aed80a5b0fe1a14cde81a4920c7bef Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 23 Apr 2025 15:23:56 -0700 Subject: [PATCH 1/3] Invalid output --- onnxscript/ir/serde.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/onnxscript/ir/serde.py b/onnxscript/ir/serde.py index bf39c1ea31..54177a44b2 100644 --- a/onnxscript/ir/serde.py +++ b/onnxscript/ir/serde.py @@ -672,8 +672,27 @@ def _deserialize_graph( for node in proto.node ] - # Fill in values for graph outputs - outputs = [deserialize_value_info_proto(info, values[info.name]) for info in proto.output] + outputs = [] + for info in proto.output: + # Fill in values for graph outputs + output_name = info.name + if output_name not in values: + # Handle (invalid) graph outputs that do not have any producers + logger.warning( + "Output '%s' is not produced by any node. The graph has an invalid output", + output_name + ) + value = _core.Value(name=output_name) + # Fill in shape/type information + deserialize_value_info_proto(info, value) + if output_name in quantization_annotations: + _deserialize_quantization_annotation( + quantization_annotations[output_name], value + ) + else: + # A valid, normal graph output + value = values[output_name] + outputs.append(value) # Exit the graph scope by popping the values for this scope from the stack scoped_values.pop() From a57d6204556efeccbd109ec5cf80175d149f2478 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 23 Apr 2025 15:30:30 -0700 Subject: [PATCH 2/3] [IR] Handle invalid outputs --- onnxscript/ir/serde.py | 2 +- onnxscript/ir/serde_test.py | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/onnxscript/ir/serde.py b/onnxscript/ir/serde.py index 54177a44b2..b165536513 100644 --- a/onnxscript/ir/serde.py +++ b/onnxscript/ir/serde.py @@ -680,7 +680,7 @@ def _deserialize_graph( # Handle (invalid) graph outputs that do not have any producers logger.warning( "Output '%s' is not produced by any node. The graph has an invalid output", - output_name + output_name, ) value = _core.Value(name=output_name) # Fill in shape/type information diff --git a/onnxscript/ir/serde_test.py b/onnxscript/ir/serde_test.py index 416020afeb..303f02761f 100644 --- a/onnxscript/ir/serde_test.py +++ b/onnxscript/ir/serde_test.py @@ -290,6 +290,23 @@ def test_deserialize_graph_handles_unsorted_graph(self): self.assertEqual(deserialized_graph[0].op_type, "Op_1") self.assertEqual(deserialized_graph[1].op_type, "Op_0") + def test_deserialize_graph_handles_invalid_output(self): + # The graph has an output that is not connected to any node, and it does not + # have shape/type information. + graph_with_invalid_output = ir.Graph( + inputs=[], + outputs=[ir.Value(name="invalid_output")], + nodes=[], + name="graph_with_invalid_output", + ) + graph_proto = serde.serialize_graph(graph_with_invalid_output) + deserialized_graph = serde.deserialize_graph(graph_proto) + self.assertEqual(len(deserialized_graph.outputs), 1) + self.assertEqual(deserialized_graph.outputs[0].name, "invalid_output") + self.assertEqual(deserialized_graph.outputs[0].type, None) + self.assertEqual(deserialized_graph.outputs[0].shape, None) + self.assertEqual(deserialized_graph.outputs[0].dtype, None) + class QuantizationAnnotationTest(unittest.TestCase): """Test that quantization annotations are correctly serialized and deserialized.""" From 05cd67e704ccc60bc8afc307e3355adce46d22ba Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 23 Apr 2025 15:50:25 -0700 Subject: [PATCH 3/3] deserialize_value_info_proto --- onnxscript/ir/serde.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/onnxscript/ir/serde.py b/onnxscript/ir/serde.py index b165536513..64703b2baa 100644 --- a/onnxscript/ir/serde.py +++ b/onnxscript/ir/serde.py @@ -683,15 +683,11 @@ def _deserialize_graph( output_name, ) value = _core.Value(name=output_name) - # Fill in shape/type information - deserialize_value_info_proto(info, value) - if output_name in quantization_annotations: - _deserialize_quantization_annotation( - quantization_annotations[output_name], value - ) else: # A valid, normal graph output value = values[output_name] + # Fill in shape/type information + deserialize_value_info_proto(info, value) outputs.append(value) # Exit the graph scope by popping the values for this scope from the stack