File tree Expand file tree Collapse file tree 2 files changed +34
-2
lines changed Expand file tree Collapse file tree 2 files changed +34
-2
lines changed Original file line number Diff line number Diff line change @@ -672,8 +672,23 @@ def _deserialize_graph(
672
672
for node in proto .node
673
673
]
674
674
675
- # Fill in values for graph outputs
676
- outputs = [deserialize_value_info_proto (info , values [info .name ]) for info in proto .output ]
675
+ outputs = []
676
+ for info in proto .output :
677
+ # Fill in values for graph outputs
678
+ output_name = info .name
679
+ if output_name not in values :
680
+ # Handle (invalid) graph outputs that do not have any producers
681
+ logger .warning (
682
+ "Output '%s' is not produced by any node. The graph has an invalid output" ,
683
+ output_name ,
684
+ )
685
+ value = _core .Value (name = output_name )
686
+ else :
687
+ # A valid, normal graph output
688
+ value = values [output_name ]
689
+ # Fill in shape/type information
690
+ deserialize_value_info_proto (info , value )
691
+ outputs .append (value )
677
692
678
693
# Exit the graph scope by popping the values for this scope from the stack
679
694
scoped_values .pop ()
Original file line number Diff line number Diff line change @@ -290,6 +290,23 @@ def test_deserialize_graph_handles_unsorted_graph(self):
290
290
self .assertEqual (deserialized_graph [0 ].op_type , "Op_1" )
291
291
self .assertEqual (deserialized_graph [1 ].op_type , "Op_0" )
292
292
293
+ def test_deserialize_graph_handles_invalid_output (self ):
294
+ # The graph has an output that is not connected to any node, and it does not
295
+ # have shape/type information.
296
+ graph_with_invalid_output = ir .Graph (
297
+ inputs = [],
298
+ outputs = [ir .Value (name = "invalid_output" )],
299
+ nodes = [],
300
+ name = "graph_with_invalid_output" ,
301
+ )
302
+ graph_proto = serde .serialize_graph (graph_with_invalid_output )
303
+ deserialized_graph = serde .deserialize_graph (graph_proto )
304
+ self .assertEqual (len (deserialized_graph .outputs ), 1 )
305
+ self .assertEqual (deserialized_graph .outputs [0 ].name , "invalid_output" )
306
+ self .assertEqual (deserialized_graph .outputs [0 ].type , None )
307
+ self .assertEqual (deserialized_graph .outputs [0 ].shape , None )
308
+ self .assertEqual (deserialized_graph .outputs [0 ].dtype , None )
309
+
293
310
294
311
class QuantizationAnnotationTest (unittest .TestCase ):
295
312
"""Test that quantization annotations are correctly serialized and deserialized."""
You can’t perform that action at this time.
0 commit comments