 from collections import namedtuple
 from dataclasses import dataclass, field
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+from unittest.mock import patch

 import sympy
 import torch
+import torch._export
 from executorch.exir.dynamic_shape import DynamicMemoryPlanningMode
 from executorch.exir.emit import emit_program, EmitterOutput
 from executorch.exir.error import ExportError, ExportErrorType, InternalError
...
 from executorch.exir.schema import Program
 from executorch.exir.serialize import serialize_to_flatbuffer
 from executorch.exir.tracer import (
+    _default_decomposition_table,
     dispatch_trace,
     dynamo_trace,
     ExirDynamoConfig,
...
 from torch._dynamo.eval_frame import Constraint
 from torch._export import CallSpec, export, ExportGraphSignature
 from torch._export.exported_program import ExportedProgram
+from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass
 from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
     InputDim,
     RangeConstraint,
...
 from torch.fx._compatibility import compatibility
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.fx.experimental.symbolic_shapes import ShapeEnv
+from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
 from torch.utils import _pytree as pytree


 Val = Any

+def _unlift(gm, inp_pos_to_param_buffer_name, in_spec, out_spec, state_dict):
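+    # torch._export lifts every parameter/buffer into an extra placeholder
+    # input; this helper reverses that. Placeholders whose positions appear
+    # in inp_pos_to_param_buffer_name become get_attr nodes again, and
+    # cond/map submodules are fixed up recursively.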
+    count = 0
+    # Step 1: turn the lifted params back into get_attr nodes
+    for node in gm.graph.nodes:
+        if node.op == "placeholder":
+            if count in inp_pos_to_param_buffer_name:
+                with gm.graph.inserting_after(node):
+                    getattr_node = gm.graph.get_attr(
+                        inp_pos_to_param_buffer_name[count]
+                    )
+                    node.replace_all_uses_with(getattr_node)
+                    metadata = node.meta
+                    gm.graph.erase_node(node)
+                    getattr_node.meta = metadata
+            count += 1
+
+    # Step 2: Fix the input/output of the graph now that we deleted
+    # some args.
+    gm.graph.lint()
+    names = [f"arg_{i}" for i in range(len(in_spec.children_specs))]
+    gm.graph._codegen = _PyTreeCodeGen(
+        _PyTreeInfo(
+            names,
+            in_spec,
+            out_spec,
+        )
+    )
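+    # _PyTreeCodeGen makes the regenerated forward() flatten its inputs with
+    # in_spec and unflatten its outputs with out_spec, preserving the
+    # original calling convention.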
+    gm.recompile()
+
+    # Step 3: Find state references in HigherOrderOps and recursively
+    # fix them.
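+    # Any cond/map operand that is itself lifted state is re-registered as a
+    # buffer on the submodule(s) and dropped from the outer call's operands.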
+    for node in gm.graph.nodes:
+        if node.op == "call_function" and node.target == torch.ops.cond:
+            pred, true_graph, false_graph, operands = node.args
+            true_gm = getattr(gm, true_graph.name)
+            false_gm = getattr(gm, false_graph.name)
+            inp_pos_to_param_buffer_name_for_submod = {}
+            real_operands = []
+            for ix, operand in enumerate(operands):
+                if operand.target in inp_pos_to_param_buffer_name.values():
+                    inp_pos_to_param_buffer_name_for_submod[ix] = operand.target
+                    true_gm.register_buffer(operand.target, state_dict[operand.target])
+                    false_gm.register_buffer(operand.target, state_dict[operand.target])
+                else:
+                    real_operands.append(operand)
+            node.args = (pred, true_graph, false_graph, real_operands)
+
+            _, in_spec = pytree.tree_flatten(real_operands)
+
+            _unlift(
+                true_gm,
+                inp_pos_to_param_buffer_name_for_submod,
+                in_spec,
+                None,
+                state_dict,
+            )
+            _unlift(
+                false_gm,
+                inp_pos_to_param_buffer_name_for_submod,
+                in_spec,
+                None,
+                state_dict,
+            )
+        if node.op == "call_function" and node.target.__name__ == "map_impl":
+            body_graph, num_mapped, *operands = node.args
+            body_gm = getattr(gm, body_graph.name)
+            inp_pos_to_buffer_name_for_submod = {}
+            real_operands = []
+            for ix, operand in enumerate(operands):
+                if operand.target in inp_pos_to_param_buffer_name.values():
+                    inp_pos_to_buffer_name_for_submod[ix] = operand.target
+                    body_gm.register_buffer(operand.target, state_dict[operand.target])
+                else:
+                    real_operands.append(operand)
+            node.args = (body_graph, num_mapped, *real_operands)
+
+            _, in_spec = pytree.tree_flatten(real_operands)
+
+            _unlift(
+                body_gm, inp_pos_to_buffer_name_for_submod, in_spec, None, state_dict
+            )
+    gm.graph.lint()
+    gm.graph.eliminate_dead_code()
+    gm.recompile()
+    return gm
+
+
+def unlift_exported_program_lifted_states(
+    ep: torch._export.exported_program.ExportedProgram,
+):
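+    # Copies the lifted state back onto a deepcopy of the exported graph
+    # module. register_buffer/register_parameter reject "." in names, so
+    # qualified names like "sub.weight" are registered as "sub_weight" and
+    # remapped below.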
+    new_gm = copy.deepcopy(ep.graph_module)
+
+    # TODO Fix the period in params/buffers names later
+    # maybe a pass to replace graph signature with fixed names
+    param_buffer_name_to_corrected_name = {}
+
+    for name, stuff in ep.state_dict.items():
+        if name in ep.graph_signature.buffers:
+            if "." in name:
+                new_gm.register_buffer(name.replace(".", "_"), stuff)
+                param_buffer_name_to_corrected_name[name] = name.replace(".", "_")
+            else:
+                new_gm.register_buffer(name, stuff)
+        elif name in ep.graph_signature.parameters:
+            if "." in name:
+                new_gm.register_parameter(name.replace(".", "_"), stuff)
+                param_buffer_name_to_corrected_name[name] = name.replace(".", "_")
+            else:
+                new_gm.register_parameter(name, stuff)
+        else:
+            raise AssertionError("encountered an unregistered param/buffer")
+
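+    # Map each placeholder's position to the (possibly corrected) attribute
+    # name it should read, using the signature's inputs_to_buffers /
+    # inputs_to_parameters records.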
+    count = 0
+    inp_pos_to_param_buffer_name = {}
+    for node in new_gm.graph.nodes:
+        if node.op == "placeholder":
+            if node.name in ep.graph_signature.inputs_to_buffers:
+                buffer_name = ep.graph_signature.inputs_to_buffers[node.name]
+                if buffer_name in param_buffer_name_to_corrected_name:
+                    inp_pos_to_param_buffer_name[
+                        count
+                    ] = param_buffer_name_to_corrected_name[buffer_name]
+                else:
+                    inp_pos_to_param_buffer_name[count] = buffer_name
+            if node.name in ep.graph_signature.inputs_to_parameters:
+                param_name = ep.graph_signature.inputs_to_parameters[node.name]
+                if param_name in param_buffer_name_to_corrected_name:
+                    inp_pos_to_param_buffer_name[
+                        count
+                    ] = param_buffer_name_to_corrected_name[param_name]
+                else:
+                    inp_pos_to_param_buffer_name[count] = param_name
+            count += 1
+    new_gm = _unlift(
+        new_gm,
+        inp_pos_to_param_buffer_name,
+        ep.call_spec.in_spec,
+        ep.call_spec.out_spec,
+        ep.state_dict,
+    )
+    return new_gm
+
+
 @compatibility(is_backward_compatible=False)
 @dataclass
 class CaptureConfig:
@@ -70,6 +218,7 @@ class CaptureConfig:
     enable_dynamic_shape: bool = False
     enable_aot: bool = False
     _dynamo_config: "ExirDynamoConfig" = ExirDynamoConfig()
+    _unlift: bool = False


 @compatibility(is_backward_compatible=False)
@@ -469,8 +618,15 @@ def capture(
             "Functionalization is required for enable_aot.",
         )

-        ep = export(f, args, _add_runtime_assertions=False, constraints=constraints)
-        return ep  # pyre-ignore
+        # TODO remove this later
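+        # Patching torch._export.DECOMP_TABLE makes export() decompose with
+        # exir's _default_decomposition_table instead of its built-in table.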
+        with patch("torch._export.DECOMP_TABLE", _default_decomposition_table()):
+            ep = export(
+                f, args, _add_runtime_assertions=False, constraints=constraints
+            )
+        ep = ep.transform(ReplaceViewOpsWithViewCopyOpsPass())
+        if not config._unlift:
+            return ep  # pyre-ignore
+        graph_module = unlift_exported_program_lifted_states(ep)
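+        # Fall through with the unlifted GraphModule instead of returning
+        # the ExportedProgram.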

     elif config.enable_dynamic_shape:
         if not config._dynamo_config.dynamic_shapes:
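
For reference, a minimal sketch of how the new flag would be driven end to end. The toy module, the inputs, and the exir.capture()/exir.CaptureConfig import paths are assumptions for illustration; only CaptureConfig(enable_aot=True, _unlift=True) and the unlift behavior come from this diff:

    import torch
    from executorch import exir  # import path assumed

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x):
            return self.linear(x)

    # With _unlift=True, the enable_aot path re-attaches params/buffers to
    # the graph module via unlift_exported_program_lifted_states.
    config = exir.CaptureConfig(enable_aot=True, _unlift=True)
    prog = exir.capture(M(), (torch.randn(4, 4),), config)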