Skip to content

Commit f245ac7

Browse files
zhxchen17facebook-github-bot
authored andcommitted
Support preserving calling convention to some modules.
Summary: X-link: pytorch/pytorch#106798 APS use this feature to swap out some submodules after unflattening. Reviewed By: kit1980, tugsbayasgalan Differential Revision: D48154341 fbshipit-source-id: 687393582596d679a51f236ae1665a5e07cdaadc
1 parent e023d8f commit f245ac7

File tree

3 files changed

+5
-1
lines changed

3 files changed

+5
-1
lines changed

exir/capture/_capture.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ def convert_to_fake(x):
212212
{},
213213
{},
214214
[],
215+
[],
215216
)
216217
return ExirExportedProgram(ep, False)
217218

exir/lowered_backend_module.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ def create_exported_program_from_submodule(
233233
# TODO(T159524653): fill in range and equality constraints
234234
range_constraints={},
235235
equality_constraints=[],
236+
module_call_graph=[],
236237
)
237238

238239

exir/serde/serialize.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class GraphModuleSerializer(export_serialize.GraphModuleSerializer):
4242
def __init__(
4343
self, graph_signature: ep.ExportGraphSignature, call_spec: ep.CallSpec
4444
) -> None:
45-
super().__init__(graph_signature, call_spec)
45+
super().__init__(graph_signature, call_spec, [])
4646
self.state_dict: Dict[str, torch.Tensor] = {} # TODO(T157676982)
4747

4848
def serialize_operator(
@@ -574,6 +574,7 @@ def deserialize(
574574
graph_module,
575575
sig,
576576
call_spec,
577+
module_call_graph,
577578
symbol_name_to_symbol,
578579
) = GraphModuleDeserializer(state_dict).deserialize(
579580
serialized_exported_program.graph_module,
@@ -603,6 +604,7 @@ def deserialize(
603604
{}, # TODO(T157676982)
604605
range_constraints,
605606
equality_constraints,
607+
[],
606608
)
607609

608610

0 commit comments

Comments
 (0)