Skip to content

ETRecord ser/de handling "None" outputs and more #3191

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions exir/serde/export_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -1190,13 +1190,17 @@ def deserialize_tensor_meta(
),
)

def deserialize_graph_output(self, output) -> torch.fx.Node:
def deserialize_graph_output(self, output) -> Optional[Union[torch.fx.Node, int]]:
if output.type == "as_tensor":
return self.serialized_name_to_node[output.as_tensor.name]
elif output.type == "as_sym_int":
return self.serialized_name_to_node[output.as_sym_int.as_name]
elif output.type == "as_sym_bool":
return self.serialized_name_to_node[output.as_sym_bool.as_name]
elif output.type == "as_int":
return output.as_int
elif output.type == "as_none":
return None
else:
raise SerializeError(f"Unable to deserialize output node {output}")

Expand Down Expand Up @@ -1249,7 +1253,8 @@ def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph:
output_node.meta["val"] = output_node.args[0].meta["val"]
else:
output_node.meta["val"] = tuple(
arg.meta["val"] for arg in output_node.args[0]
arg.meta["val"] if isinstance(arg, torch.fx.Node) else arg
for arg in output_node.args[0]
)

return self.graph
Expand Down
34 changes: 16 additions & 18 deletions exir/serde/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
LoweredBackendModule as SerdeLoweredBackendModule,
)
from torch._export.serde.schema import SchemaVersion
from torch._export.serde.serialize import SerializeError
from torch._export.serde.union import _Union
from torch._export.verifier import load_verifier
from torch.fx.experimental import symbolic_shapes
Expand Down Expand Up @@ -479,23 +478,22 @@ def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]:

return res

def deserialize_graph_output(self, output: schema.Argument) -> torch.fx.Node:
if isinstance(output.value, schema.TensorArgument):
if output.value.name in self.state_dict: # TODO(T157676982)
val = self.state_dict[output.value.name]
setattr(self.module, output.value.name, val)
node = self.graph.create_node(
"get_attr",
output.value.name,
name=output.value.name,
)
node.meta = {"val": ""}
return node
return self.serialized_name_to_node[output.value.name]
elif isinstance(output.value, (schema.SymIntArgument, schema.SymBoolArgument)):
return self.serialized_name_to_node[output.value.as_name]
else:
raise SerializeError(f"Unable to deserialize output node {output}")
def deserialize_graph_output(
self, output: schema.Argument
) -> Optional[Union[torch.fx.Node, int]]:
if (
output.type == "as_tensor" and output.value.name in self.state_dict
): # TODO(T157676982)
val = self.state_dict[output.value.name]
setattr(self.module, output.value.name, val)
node = self.graph.create_node(
"get_attr",
output.value.name,
name=output.value.name,
)
node.meta = {"val": ""}
return node
return super().deserialize_graph_output(output)

# pyre-ignore
def deserialize_alloc_inputs(self, serialized_inputs: List[schema.NamedArgument]):
Expand Down
7 changes: 6 additions & 1 deletion sdk/debug_format/et_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,12 @@ def gen_operator_graph(
assert len(args) == 1
# Args of op=='output' is a wrapped list of return nodes ([ret_1, ret_2, ...], )
in_nodes = [
nodes[FXOperatorGraph._get_node_name(ret)] for ret in args[0]
(
nodes[FXOperatorGraph._get_node_name(ret)]
if ret is not None
else []
)
for ret in args[0]
]
node = ValueNode(
name,
Expand Down
Loading