Skip to content

Commit 25b05c2

Browse files
Olivia-liupytorchbot
authored andcommitted
ETRecord ser/de handling "None" outputs and more (#3039)
Summary: Pull Request resolved: #3039 For the ease of communication, let me assign nicknames to the files related to this diff: * File A: *caffe2/torch/_export/serde/serialize.py* * File B: *executorch/exir/serde/serialize.py* * File C: *executorch/exir/serde/export_serialize.py* Recently, we noticed that error `torch._export.serde.serialize.SerializeError: Unable to deserialize output node Argument(as_none=[])` (P1210590561) was thrown from File B when deserializing ETRecord. It's possible that the error has been there since the beginning, but we've just never tested that logic path. In this diff, I made a fix on File B to resolve this particular issue. Also adding handling for "None" output case in sdk logic. ***Keep on reading if you don't think the code changes make sense:*** I explored the history of file changes. In chronological order: 1. D48258552, `deserialize_graph_output()` was copied from File A to File B, with some modifications made. The `deserialize_graph_output()` in File B overrides that in File A due to polymorphism. 2. D52446586, File C was created by ***copying*** File A. As a result of this diff, the `deserialize_graph_output()` in File B now overrides that in File C. 3. Also in D52446586, the `deserialize_graph_output()` in File A had some significant changes; File C got the new version of `deserialize_graph_output()`. But this diff didn't update the `deserialize_graph_output()` in File B. 4. D55391674 added the handling for "None" outputs to File A. This diff brings (parts of) File C up-to-date with File A, and make `deserialize_graph_output()` in File B properly overrides that in File A. In the future, we should figure out how to keep File C and File A in sync. Recently, File C was broken because it didn't stay in sync with File A in D54855251 and had to be fixed by D55776877. There will be a design review session this Friday to discuss consolidating the serialization code for edge and export. Reviewed By: tarun292 Differential Revision: D56091104 fbshipit-source-id: 20c75ddc610c3be7ab2bb62943419d3b8b2be079 (cherry picked from commit 89cfa73)
1 parent 7b29ad2 commit 25b05c2

File tree

3 files changed

+29
-21
lines changed

3 files changed

+29
-21
lines changed

exir/serde/export_serialize.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -1190,13 +1190,17 @@ def deserialize_tensor_meta(
11901190
),
11911191
)
11921192

1193-
def deserialize_graph_output(self, output) -> torch.fx.Node:
1193+
def deserialize_graph_output(self, output) -> Optional[Union[torch.fx.Node, int]]:
11941194
if output.type == "as_tensor":
11951195
return self.serialized_name_to_node[output.as_tensor.name]
11961196
elif output.type == "as_sym_int":
11971197
return self.serialized_name_to_node[output.as_sym_int.as_name]
11981198
elif output.type == "as_sym_bool":
11991199
return self.serialized_name_to_node[output.as_sym_bool.as_name]
1200+
elif output.type == "as_int":
1201+
return output.as_int
1202+
elif output.type == "as_none":
1203+
return None
12001204
else:
12011205
raise SerializeError(f"Unable to deserialize output node {output}")
12021206

@@ -1249,7 +1253,8 @@ def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph:
12491253
output_node.meta["val"] = output_node.args[0].meta["val"]
12501254
else:
12511255
output_node.meta["val"] = tuple(
1252-
arg.meta["val"] for arg in output_node.args[0]
1256+
arg.meta["val"] if isinstance(arg, torch.fx.Node) else arg
1257+
for arg in output_node.args[0]
12531258
)
12541259

12551260
return self.graph

exir/serde/serialize.py

+16-18
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
LoweredBackendModule as SerdeLoweredBackendModule,
3636
)
3737
from torch._export.serde.schema import SchemaVersion
38-
from torch._export.serde.serialize import SerializeError
3938
from torch._export.serde.union import _Union
4039
from torch._export.verifier import load_verifier
4140
from torch.fx.experimental import symbolic_shapes
@@ -479,23 +478,22 @@ def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]:
479478

480479
return res
481480

482-
def deserialize_graph_output(self, output: schema.Argument) -> torch.fx.Node:
483-
if isinstance(output.value, schema.TensorArgument):
484-
if output.value.name in self.state_dict: # TODO(T157676982)
485-
val = self.state_dict[output.value.name]
486-
setattr(self.module, output.value.name, val)
487-
node = self.graph.create_node(
488-
"get_attr",
489-
output.value.name,
490-
name=output.value.name,
491-
)
492-
node.meta = {"val": ""}
493-
return node
494-
return self.serialized_name_to_node[output.value.name]
495-
elif isinstance(output.value, (schema.SymIntArgument, schema.SymBoolArgument)):
496-
return self.serialized_name_to_node[output.value.as_name]
497-
else:
498-
raise SerializeError(f"Unable to deserialize output node {output}")
481+
def deserialize_graph_output(
482+
self, output: schema.Argument
483+
) -> Optional[Union[torch.fx.Node, int]]:
484+
if (
485+
output.type == "as_tensor" and output.value.name in self.state_dict
486+
): # TODO(T157676982)
487+
val = self.state_dict[output.value.name]
488+
setattr(self.module, output.value.name, val)
489+
node = self.graph.create_node(
490+
"get_attr",
491+
output.value.name,
492+
name=output.value.name,
493+
)
494+
node.meta = {"val": ""}
495+
return node
496+
return super().deserialize_graph_output(output)
499497

500498
# pyre-ignore
501499
def deserialize_alloc_inputs(self, serialized_inputs: List[schema.NamedArgument]):

sdk/debug_format/et_schema.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,12 @@ def gen_operator_graph(
260260
assert len(args) == 1
261261
# Args of op=='output' is a wrapped list of return nodes ([ret_1, ret_2, ...], )
262262
in_nodes = [
263-
nodes[FXOperatorGraph._get_node_name(ret)] for ret in args[0]
263+
(
264+
nodes[FXOperatorGraph._get_node_name(ret)]
265+
if ret is not None
266+
else []
267+
)
268+
for ret in args[0]
264269
]
265270
node = ValueNode(
266271
name,

0 commit comments

Comments
 (0)