Skip to content

Commit e6b8cc2

Browse files
tugsbayasgalanfacebook-github-bot
authored andcommitted
Add unlifting pass under private config (#4)
Summary: X-link: pytorch/pytorch#104897 Pull Request resolved: #4 We wanna do this little by little. For now, I tried only on DissectedPartsModel which needs to use aot_export version. Differential Revision: D46785735 fbshipit-source-id: 4225a68675315e99e23bd65be70819dc5feb13f0
1 parent 561f4c5 commit e6b8cc2

File tree

6 files changed

+360
-48
lines changed

6 files changed

+360
-48
lines changed

backends/test/test_backends.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -617,17 +617,13 @@ def forward(self, x_raw, h, c):
617617
orig_res = composite_m(*inputs)
618618

619619
traced = (
620-
exir.capture(composite_m, inputs, exir.CaptureConfig(pt2_mode=True))
620+
exir.capture(composite_m, inputs)
621621
.to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
622622
.graph_module
623623
)
624624

625625
program_without_delegates = (
626-
exir.capture(
627-
composite_m,
628-
(input_x, input_h, input_c),
629-
exir.CaptureConfig(pt2_mode=True),
630-
)
626+
exir.capture(CompositeModel(3), inputs)
631627
.to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
632628
.to_executorch(
633629
config=exir.ExecutorchBackendConfig(extract_segments=extract_segments),
@@ -732,7 +728,7 @@ def forward(self, x_raw, h, c):
732728

733729
program_without_delegates = (
734730
exir.capture(
735-
composite_m,
731+
CompositeModel(3),
736732
(input_x, input_h, input_c),
737733
exir.CaptureConfig(pt2_mode=True),
738734
)
@@ -983,7 +979,8 @@ def test_quantized_with_delegate(self) -> None:
983979
example_inputs,
984980
exir.CaptureConfig(
985981
pt2_mode=True,
986-
enable_functionalization=False,
982+
enable_aot=True,
983+
_unlift=True,
987984
),
988985
).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
989986
FileCheck().check_count("quantize_per_tensor.default", 3).check("addmm").run(

exir/__init__.py

Lines changed: 158 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
from collections import namedtuple
66
from dataclasses import dataclass, field
77
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
8+
from unittest.mock import patch
89

910
import sympy
1011
import torch
12+
import torch._export
1113
from executorch.exir.dynamic_shape import DynamicMemoryPlanningMode
1214
from executorch.exir.emit import emit_program, EmitterOutput
1315
from executorch.exir.error import ExportError, ExportErrorType, InternalError
@@ -32,6 +34,7 @@
3234
from executorch.exir.schema import Program
3335
from executorch.exir.serialize import serialize_to_flatbuffer
3436
from executorch.exir.tracer import (
37+
_default_decomposition_table,
3538
dispatch_trace,
3639
dynamo_trace,
3740
ExirDynamoConfig,
@@ -48,6 +51,7 @@
4851
from torch._dynamo.eval_frame import Constraint
4952
from torch._export import CallSpec, export, ExportGraphSignature
5053
from torch._export.exported_program import ExportedProgram
54+
from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass
5155
from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
5256
InputDim,
5357
RangeConstraint,
@@ -56,12 +60,156 @@
5660
from torch.fx._compatibility import compatibility
5761
from torch.fx.experimental.proxy_tensor import make_fx
5862
from torch.fx.experimental.symbolic_shapes import ShapeEnv
63+
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
5964
from torch.utils import _pytree as pytree
6065

6166

6267
Val = Any
6368

6469

70+
def _unlift(gm, inp_pos_to_param_buffer_name, in_spec, out_spec, state_dict):
71+
count = 0
72+
# Step 1: make lifted params as get_attr
73+
for node in gm.graph.nodes:
74+
if node.op == "placeholder":
75+
if count in inp_pos_to_param_buffer_name:
76+
with gm.graph.inserting_after(node):
77+
getattr_node = gm.graph.get_attr(
78+
inp_pos_to_param_buffer_name[count]
79+
)
80+
node.replace_all_uses_with(getattr_node)
81+
metadata = node.meta
82+
gm.graph.erase_node(node)
83+
getattr_node.meta = metadata
84+
count += 1
85+
86+
# Step 2: Fix the input/output of the graph now that we deleted
87+
# some args.
88+
gm.graph.lint()
89+
names = [f"arg_{i}" for i in range(len(in_spec.children_specs))]
90+
gm.graph._codegen = _PyTreeCodeGen(
91+
_PyTreeInfo(
92+
names,
93+
in_spec,
94+
out_spec,
95+
)
96+
)
97+
gm.recompile()
98+
99+
# Step 3: Find state references in HigherOrderOps and recursively
100+
# fix them.
101+
for node in gm.graph.nodes:
102+
if node.op == "call_function" and node.target == torch.ops.cond:
103+
pred, true_graph, false_graph, operands = node.args
104+
true_gm = getattr(gm, true_graph.name)
105+
false_gm = getattr(gm, false_graph.name)
106+
inp_pos_to_param_buffer_name_for_submod = {}
107+
real_operands = []
108+
for ix, operand in enumerate(operands):
109+
if operand.target in inp_pos_to_param_buffer_name.values():
110+
inp_pos_to_param_buffer_name_for_submod[ix] = operand.target
111+
true_gm.register_buffer(operand.target, state_dict[operand.target])
112+
false_gm.register_buffer(operand.target, state_dict[operand.target])
113+
else:
114+
real_operands.append(operand)
115+
node.args = (pred, true_graph, false_graph, real_operands)
116+
117+
_, in_spec = pytree.tree_flatten(real_operands)
118+
119+
_unlift(
120+
true_gm,
121+
inp_pos_to_param_buffer_name_for_submod,
122+
in_spec,
123+
None,
124+
state_dict,
125+
)
126+
_unlift(
127+
false_gm,
128+
inp_pos_to_param_buffer_name_for_submod,
129+
in_spec,
130+
None,
131+
state_dict,
132+
)
133+
if node.op == "call_function" and node.target.__name__ == "map_impl":
134+
body_graph, num_mapped, *operands = node.args
135+
body_gm = getattr(gm, body_graph.name)
136+
inp_pos_to_buffer_name_for_submod = {}
137+
real_operands = []
138+
for ix, operand in enumerate(operands):
139+
if operand.target in inp_pos_to_param_buffer_name.values():
140+
inp_pos_to_buffer_name_for_submod[ix] = operand.target
141+
body_gm.register_buffer(operand.target, state_dict[operand.target])
142+
else:
143+
real_operands.append(operand)
144+
node.args = (body_graph, num_mapped, *real_operands)
145+
146+
_, in_spec = pytree.tree_flatten(real_operands)
147+
148+
_unlift(
149+
body_gm, inp_pos_to_buffer_name_for_submod, in_spec, None, state_dict
150+
)
151+
gm.graph.lint()
152+
gm.graph.eliminate_dead_code()
153+
gm.recompile()
154+
return gm
155+
156+
157+
def unlift_exported_program_lifted_states(
158+
ep: torch._export.exported_program.ExportedProgram,
159+
):
160+
new_gm = copy.deepcopy(ep.graph_module)
161+
162+
# TODO Fix the period in params/buffers names later
163+
# maybe a pass to replace graph signature with fixed names
164+
param_buffer_name_to_corrected_name = {}
165+
166+
for name, stuff in ep.state_dict.items():
167+
if name in ep.graph_signature.buffers:
168+
if "." in name:
169+
new_gm.register_buffer(name.replace(".", "_"), stuff)
170+
param_buffer_name_to_corrected_name[name] = name.replace(".", "_")
171+
else:
172+
new_gm.register_buffer(name, stuff)
173+
elif name in ep.graph_signature.parameters:
174+
if "." in name:
175+
new_gm.register_parameter(name.replace(".", "_"), stuff)
176+
param_buffer_name_to_corrected_name[name] = name.replace(".", "_")
177+
else:
178+
new_gm.register_parameter(name, stuff)
179+
else:
180+
raise AssertionError("encountered not registered param/buffer")
181+
182+
count = 0
183+
inp_pos_to_param_buffer_name = {}
184+
for node in new_gm.graph.nodes:
185+
if node.op == "placeholder":
186+
if node.name in ep.graph_signature.inputs_to_buffers:
187+
buffer_name = ep.graph_signature.inputs_to_buffers[node.name]
188+
if buffer_name in param_buffer_name_to_corrected_name:
189+
inp_pos_to_param_buffer_name[
190+
count
191+
] = param_buffer_name_to_corrected_name[buffer_name]
192+
else:
193+
inp_pos_to_param_buffer_name[count] = buffer_name
194+
if node.name in ep.graph_signature.inputs_to_parameters:
195+
param_name = ep.graph_signature.inputs_to_parameters[node.name]
196+
if param_name in param_buffer_name_to_corrected_name:
197+
inp_pos_to_param_buffer_name[
198+
count
199+
] = param_buffer_name_to_corrected_name[param_name]
200+
else:
201+
inp_pos_to_param_buffer_name[count] = param_name
202+
count += 1
203+
new_gm = _unlift(
204+
new_gm,
205+
inp_pos_to_param_buffer_name,
206+
ep.call_spec.in_spec,
207+
ep.call_spec.out_spec,
208+
ep.state_dict,
209+
)
210+
return new_gm
211+
212+
65213
@compatibility(is_backward_compatible=False)
66214
@dataclass
67215
class CaptureConfig:
@@ -70,6 +218,7 @@ class CaptureConfig:
70218
enable_dynamic_shape: bool = False
71219
enable_aot: bool = False
72220
_dynamo_config: "ExirDynamoConfig" = ExirDynamoConfig()
221+
_unlift: bool = False
73222

74223

75224
@compatibility(is_backward_compatible=False)
@@ -469,8 +618,15 @@ def capture(
469618
"Functionalization is required for enable_aot.",
470619
)
471620

472-
ep = export(f, args, _add_runtime_assertions=False, constraints=constraints)
473-
return ep # pyre-ignore
621+
# TODO remove this later
622+
with patch("torch._export.DECOMP_TABLE", _default_decomposition_table()):
623+
ep = export(
624+
f, args, _add_runtime_assertions=False, constraints=constraints
625+
)
626+
ep = ep.transform(ReplaceViewOpsWithViewCopyOpsPass())
627+
if not config._unlift:
628+
return ep # pyre-ignore
629+
graph_module = unlift_exported_program_lifted_states(ep)
474630

475631
elif config.enable_dynamic_shape:
476632
if not config._dynamo_config.dynamic_shapes:

exir/dialects/edge/edge.yaml

Lines changed: 58 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,14 @@
8989
mat2: T0
9090
__ret_0: T0
9191

92+
- func: aten::arange.start_step
93+
namespace: edge
94+
inherits: aten::arange.start_step
95+
type_alias:
96+
T0: [Byte, Char, Double, Float, Int, Long, Short]
97+
type_constraint:
98+
- __ret_0: T0
99+
92100
- func: aten::bmm
93101
namespace: edge
94102
inherits: aten::bmm
@@ -198,14 +206,43 @@
198206
- self: T0
199207
__ret_0: T0
200208

201-
- func: aten::lift_fresh_copy
209+
- func: aten::index_select
202210
namespace: edge
203-
inherits: aten::lift_fresh_copy
211+
inherits: aten::index_select
204212
type_alias:
205-
T0: [Bool, Byte, Char, Double, Float, Int, Long, Short]
213+
T0: [Bool]
214+
T1: [Byte]
215+
T2: [Char]
216+
T3: [Double]
217+
T4: [Float]
218+
T5: [Int]
219+
T6: [Long]
220+
T7: [Short]
206221
type_constraint:
207222
- self: T0
223+
index: T6
208224
__ret_0: T0
225+
- self: T1
226+
index: T6
227+
__ret_0: T1
228+
- self: T2
229+
index: T6
230+
__ret_0: T2
231+
- self: T3
232+
index: T6
233+
__ret_0: T3
234+
- self: T4
235+
index: T6
236+
__ret_0: T4
237+
- self: T5
238+
index: T6
239+
__ret_0: T5
240+
- self: T6
241+
index: T6
242+
__ret_0: T6
243+
- self: T7
244+
index: T6
245+
__ret_0: T7
209246

210247
- func: aten::masked_fill.Scalar
211248
namespace: edge
@@ -245,16 +282,6 @@
245282
mask: T0
246283
__ret_0: T7
247284

248-
- func: aten::minimum
249-
namespace: edge
250-
inherits: aten::minimum
251-
type_alias:
252-
T0: [Bool, Byte, Char, Double, Float, Int, Long, Short]
253-
type_constraint:
254-
- self: T0
255-
other: T0
256-
__ret_0: T0
257-
258285
- func: aten::mm
259286
namespace: edge
260287
inherits: aten::mm
@@ -324,15 +351,6 @@
324351
- self: T0
325352
__ret_0: T0
326353

327-
- func: aten::select_copy.int
328-
namespace: edge
329-
inherits: aten::select_copy.int
330-
type_alias:
331-
T0: [Bool, Byte, Char, Double, Float, Int, Long, Short]
332-
type_constraint:
333-
- self: T0
334-
__ret_0: T0
335-
336354
- func: aten::sigmoid
337355
namespace: edge
338356
inherits: aten::sigmoid
@@ -383,9 +401,25 @@
383401
other: T0
384402
__ret_0: T0
385403

386-
- func: aten::t
404+
- func: aten::sym_numel
405+
namespace: edge
406+
inherits: aten::sym_numel
407+
type_alias:
408+
T0: [Bool, Byte, Char, Double, Float, Int, Long, Short]
409+
type_constraint:
410+
- self: T0
411+
412+
- func: aten::sym_size.int
413+
namespace: edge
414+
inherits: aten::sym_size.int
415+
type_alias:
416+
T0: [Bool, Byte, Char, Double, Float, Int, Long, Short]
417+
type_constraint:
418+
- self: T0
419+
420+
- func: aten::t_copy
387421
namespace: edge
388-
inherits: aten::t
422+
inherits: aten::t_copy
389423
type_alias:
390424
T0: [Bool, Byte, Char, Double, Float, Int, Long, Short]
391425
type_constraint:

exir/dialects/edge/yaml_generator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ def get_test_gen_key(op_name: str) -> str:
143143
opdb_key = opdb_key[:-5]
144144
elif opdb_key == "sym_size":
145145
opdb_key = "resize_"
146+
elif opdb_key == "sym_numel":
147+
opdb_key = "abs"
146148
elif opdb_key == "convolution":
147149
opdb_key = "conv_transpose2d"
148150
elif opdb_key == "embedding":

0 commit comments

Comments
 (0)