Skip to content

Commit 655f5d1

Browse files
Olivia-liufacebook-github-bot
authored andcommitted
ETRecord ser/de handling "None" outputs and more (#3039)
Summary: For the ease of communication, let me assign nicknames to the files related to this diff: * File A: *caffe2/torch/_export/serde/serialize.py* * File B: *executorch/exir/serde/serialize.py* * File C: *executorch/exir/serde/export_serialize.py* Recently, we noticed that error `torch._export.serde.serialize.SerializeError: Unable to deserialize output node Argument(as_none=[])` (P1210590561) was thrown from File B when deserializing ETRecord. It's possible that the error has been there since the beginning, but we've just never tested that logic path. In this diff, I made a fix on File B to resolve this particular issue. Also adding handling for "None" output case in sdk logic. ***Keep on reading if you don't think the code changes make sense:*** I explored the history of file changes. In chronological order: 1. D48258552, `deserialize_graph_output()` was copied from File A to File B, with some modifications made. The `deserialize_graph_output()` in File B overrides that in File A due to polymorphism. 2. D52446586, File C was created by ***copying*** File A. As a result of this diff, the `deserialize_graph_output()` in File B now overrides that in File C. 3. Also in D52446586, the `deserialize_graph_output()` in File A had some significant changes; File C got the new version of `deserialize_graph_output()`. But this diff didn't update the `deserialize_graph_output()` in File B. 4. D55391674 added the handling for "None" outputs to File A. This diff brings (parts of) File C up-to-date with File A, and make `deserialize_graph_output()` in File B properly overrides that in File A. In the future, we should figure out how to keep File C and File A in sync. Recently, File C was broken because it didn't stay in sync with File A in D54855251 and had to be fixed by D55776877. There will be a design review session this Friday to discuss consolidating the serialization code for edge and export. Reviewed By: tarun292 Differential Revision: D56091104
1 parent 7c81155 commit 655f5d1

File tree

3 files changed

+29
-20
lines changed

3 files changed

+29
-20
lines changed

exir/serde/export_serialize.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -1190,13 +1190,17 @@ def deserialize_tensor_meta(
11901190
),
11911191
)
11921192

1193-
def deserialize_graph_output(self, output) -> torch.fx.Node:
1193+
def deserialize_graph_output(self, output) -> Optional[Union[torch.fx.Node, int]]:
11941194
if output.type == "as_tensor":
11951195
return self.serialized_name_to_node[output.as_tensor.name]
11961196
elif output.type == "as_sym_int":
11971197
return self.serialized_name_to_node[output.as_sym_int.as_name]
11981198
elif output.type == "as_sym_bool":
11991199
return self.serialized_name_to_node[output.as_sym_bool.as_name]
1200+
elif output.type == "as_int":
1201+
return output.as_int
1202+
elif output.type == "as_none":
1203+
return None
12001204
else:
12011205
raise SerializeError(f"Unable to deserialize output node {output}")
12021206

@@ -1249,7 +1253,8 @@ def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph:
12491253
output_node.meta["val"] = output_node.args[0].meta["val"]
12501254
else:
12511255
output_node.meta["val"] = tuple(
1252-
arg.meta["val"] for arg in output_node.args[0]
1256+
arg.meta["val"] if isinstance(arg, torch.fx.Node) else arg
1257+
for arg in output_node.args[0]
12531258
)
12541259

12551260
return self.graph

exir/serde/serialize.py

+16-17
Original file line numberDiff line numberDiff line change
@@ -484,23 +484,22 @@ def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]:
484484

485485
return res
486486

487-
def deserialize_graph_output(self, output: schema.Argument) -> torch.fx.Node:
488-
if isinstance(output.value, schema.TensorArgument):
489-
if output.value.name in self.state_dict: # TODO(T157676982)
490-
val = self.state_dict[output.value.name]
491-
setattr(self.module, output.value.name, val)
492-
node = self.graph.create_node(
493-
"get_attr",
494-
output.value.name,
495-
name=output.value.name,
496-
)
497-
node.meta = {"val": ""}
498-
return node
499-
return self.serialized_name_to_node[output.value.name]
500-
elif isinstance(output.value, (schema.SymIntArgument, schema.SymBoolArgument)):
501-
return self.serialized_name_to_node[output.value.as_name]
502-
else:
503-
raise SerializeError(f"Unable to deserialize output node {output}")
487+
def deserialize_graph_output(
488+
self, output: schema.Argument
489+
) -> Optional[Union[torch.fx.Node, int]]:
490+
if (
491+
output.type == "as_tensor" and output.value.name in self.state_dict
492+
): # TODO(T157676982)
493+
val = self.state_dict[output.value.name]
494+
setattr(self.module, output.value.name, val)
495+
node = self.graph.create_node(
496+
"get_attr",
497+
output.value.name,
498+
name=output.value.name,
499+
)
500+
node.meta = {"val": ""}
501+
return node
502+
return super().deserialize_graph_output(output)
504503

505504
# pyre-ignore
506505
def deserialize_alloc_inputs(self, serialized_inputs: List[schema.NamedArgument]):

sdk/debug_format/et_schema.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,12 @@ def gen_operator_graph(
260260
assert len(args) == 1
261261
# Args of op=='output' is a wrapped list of return nodes ([ret_1, ret_2, ...], )
262262
in_nodes = [
263-
nodes[FXOperatorGraph._get_node_name(ret)] for ret in args[0]
263+
(
264+
nodes[FXOperatorGraph._get_node_name(ret)]
265+
if ret is not None
266+
else []
267+
)
268+
for ret in args[0]
264269
]
265270
node = ValueNode(
266271
name,

0 commit comments

Comments
 (0)