Add unlifting pass under private config #4

Closed · wants to merge 1 commit
11 changes: 4 additions & 7 deletions backends/test/test_backends.py
@@ -618,11 +618,7 @@ def forward(self, x_raw, h, c):
).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))

program_without_delegates = (
exir.capture(
composite_m,
(input_x, input_h, input_c),
exir.CaptureConfig(pt2_mode=True),
)
exir.capture(CompositeModel(3), inputs)
.to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
.to_executorch(
config=exir.ExecutorchBackendConfig(extract_segments=extract_segments),
@@ -726,7 +722,7 @@ def forward(self, x_raw, h, c):

program_without_delegates = (
exir.capture(
composite_m,
CompositeModel(3),
(input_x, input_h, input_c),
exir.CaptureConfig(pt2_mode=True),
)
@@ -962,7 +958,8 @@ def test_quantized_with_delegate(self) -> None:
example_inputs,
exir.CaptureConfig(
pt2_mode=True,
enable_functionalization=False,
enable_aot=True,
_unlift=True,
),
).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
FileCheck().check_count("quantize_per_tensor.default", 3).check("addmm").run(
160 changes: 158 additions & 2 deletions exir/__init__.py
@@ -5,9 +5,11 @@
from collections import namedtuple
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from unittest.mock import patch

import sympy
import torch
import torch._export
from executorch.exir.dynamic_shape import DynamicMemoryPlanningMode
from executorch.exir.emit import emit_program, EmitterOutput
from executorch.exir.error import ExportError, ExportErrorType, InternalError
@@ -25,6 +27,7 @@
from executorch.exir.schema import Program
from executorch.exir.serialize import serialize_to_flatbuffer
from executorch.exir.tracer import (
_default_decomposition_table,
dispatch_trace,
dynamo_trace,
ExirDynamoConfig,
@@ -41,6 +44,7 @@
from torch._dynamo.eval_frame import Constraint
from torch._export import CallSpec, export, ExportGraphSignature
from torch._export.exported_program import ExportedProgram
from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass
from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
InputDim,
RangeConstraint,
@@ -49,12 +53,156 @@
from torch.fx._compatibility import compatibility
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
from torch.utils import _pytree as pytree


Val = Any


def _unlift(gm, inp_pos_to_param_buffer_name, in_spec, out_spec, state_dict):
count = 0
# Step 1: make lifted params as get_attr
for node in gm.graph.nodes:
if node.op == "placeholder":
if count in inp_pos_to_param_buffer_name:
with gm.graph.inserting_after(node):
getattr_node = gm.graph.get_attr(
inp_pos_to_param_buffer_name[count]
)
node.replace_all_uses_with(getattr_node)
metadata = node.meta
gm.graph.erase_node(node)
getattr_node.meta = metadata
count += 1

# Step 2: Fix the input/output of the graph now that we deleted
# some args.
gm.graph.lint()
names = [f"arg_{i}" for i in range(len(in_spec.children_specs))]
gm.graph._codegen = _PyTreeCodeGen(
_PyTreeInfo(
names,
in_spec,
out_spec,
)
)
gm.recompile()

# Step 3: Find state references in HigherOrderOps and recursively
# fix them.
for node in gm.graph.nodes:
if node.op == "call_function" and node.target == torch.ops.cond:
pred, true_graph, false_graph, operands = node.args
true_gm = getattr(gm, true_graph.name)
false_gm = getattr(gm, false_graph.name)
inp_pos_to_param_buffer_name_for_submod = {}
real_operands = []
for ix, operand in enumerate(operands):
if operand.target in inp_pos_to_param_buffer_name.values():
inp_pos_to_param_buffer_name_for_submod[ix] = operand.target
true_gm.register_buffer(operand.target, state_dict[operand.target])
false_gm.register_buffer(operand.target, state_dict[operand.target])
else:
real_operands.append(operand)
node.args = (pred, true_graph, false_graph, real_operands)

_, in_spec = pytree.tree_flatten(real_operands)

_unlift(
true_gm,
inp_pos_to_param_buffer_name_for_submod,
in_spec,
None,
state_dict,
)
_unlift(
false_gm,
inp_pos_to_param_buffer_name_for_submod,
in_spec,
None,
state_dict,
)
if node.op == "call_function" and node.target.__name__ == "map_impl":
body_graph, num_mapped, *operands = node.args
body_gm = getattr(gm, body_graph.name)
inp_pos_to_buffer_name_for_submod = {}
real_operands = []
for ix, operand in enumerate(operands):
if operand.target in inp_pos_to_param_buffer_name.values():
inp_pos_to_buffer_name_for_submod[ix] = operand.target
body_gm.register_buffer(operand.target, state_dict[operand.target])
else:
real_operands.append(operand)
node.args = (body_graph, num_mapped, *real_operands)

_, in_spec = pytree.tree_flatten(real_operands)

_unlift(
body_gm, inp_pos_to_buffer_name_for_submod, in_spec, None, state_dict
)
gm.graph.lint()
gm.graph.eliminate_dead_code()
gm.recompile()
return gm


def unlift_exported_program_lifted_states(
ep: torch._export.exported_program.ExportedProgram,
):
new_gm = copy.deepcopy(ep.graph_module)

# TODO Fix the period in params/buffers names later
# maybe a pass to replace graph signature with fixed names
param_buffer_name_to_corrected_name = {}

for name, stuff in ep.state_dict.items():
if name in ep.graph_signature.buffers:
if "." in name:
new_gm.register_buffer(name.replace(".", "_"), stuff)
param_buffer_name_to_corrected_name[name] = name.replace(".", "_")
else:
new_gm.register_buffer(name, stuff)
elif name in ep.graph_signature.parameters:
if "." in name:
new_gm.register_parameter(name.replace(".", "_"), stuff)
param_buffer_name_to_corrected_name[name] = name.replace(".", "_")
else:
new_gm.register_parameter(name, stuff)
else:
raise AssertionError("encountered not registered param/buffer")

count = 0
inp_pos_to_param_buffer_name = {}
for node in new_gm.graph.nodes:
if node.op == "placeholder":
if node.name in ep.graph_signature.inputs_to_buffers:
buffer_name = ep.graph_signature.inputs_to_buffers[node.name]
if buffer_name in param_buffer_name_to_corrected_name:
inp_pos_to_param_buffer_name[
count
] = param_buffer_name_to_corrected_name[buffer_name]
else:
inp_pos_to_param_buffer_name[count] = buffer_name
if node.name in ep.graph_signature.inputs_to_parameters:
param_name = ep.graph_signature.inputs_to_parameters[node.name]
if param_name in param_buffer_name_to_corrected_name:
inp_pos_to_param_buffer_name[
count
] = param_buffer_name_to_corrected_name[param_name]
else:
inp_pos_to_param_buffer_name[count] = param_name
count += 1
new_gm = _unlift(
new_gm,
inp_pos_to_param_buffer_name,
ep.call_spec.in_spec,
ep.call_spec.out_spec,
ep.state_dict,
)
return new_gm


@compatibility(is_backward_compatible=False)
@dataclass
class CaptureConfig:
@@ -63,6 +211,7 @@ class CaptureConfig:
enable_dynamic_shape: bool = False
enable_aot: bool = False
_dynamo_config: "ExirDynamoConfig" = ExirDynamoConfig()
_unlift: bool = False


@compatibility(is_backward_compatible=False)
@@ -400,8 +549,15 @@ def capture(
"Functionalization is required for enable_aot.",
)

ep = export(f, args, _add_runtime_assertions=False, constraints=constraints)
return ep # pyre-ignore
# TODO remove this later
with patch("torch._export.DECOMP_TABLE", _default_decomposition_table()):
ep = export(
f, args, _add_runtime_assertions=False, constraints=constraints
)
ep = ep.transform(ReplaceViewOpsWithViewCopyOpsPass())
if not config._unlift:
return ep # pyre-ignore
graph_module = unlift_exported_program_lifted_states(ep)

elif config.enable_dynamic_shape:
if not config._dynamo_config.dynamic_shapes:
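For context, a minimal usage sketch of the new private flag, mirroring the updated quantized test above. The module and inputs here are hypothetical placeholders; only the CaptureConfig fields come from this patch.

import torch
import executorch.exir as exir

class ToyModule(torch.nn.Module):
    # Hypothetical module, used only to illustrate the call shape.
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

example_inputs = (torch.randn(2, 4),)

# enable_aot=True routes capture() through torch._export.export; with
# _unlift=True the lifted parameter/buffer placeholders are turned back into
# get_attr nodes (via unlift_exported_program_lifted_states) before lowering
# to the edge dialect.
edge_program = exir.capture(
    ToyModule(),
    example_inputs,
    exir.CaptureConfig(pt2_mode=True, enable_aot=True, _unlift=True),
).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))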
82 changes: 58 additions & 24 deletions exir/dialects/edge/edge.yaml
@@ -89,6 +89,14 @@
mat2: T0
__ret_0: T0

- func: aten::arange.start_step
namespace: edge
inherits: aten::arange.start_step
type_alias:
T0: [Byte, Char, Double, Float, Int, Long, Short]
type_constraint:
- __ret_0: T0

- func: aten::bmm
namespace: edge
inherits: aten::bmm
@@ -198,14 +206,43 @@
- self: T0
__ret_0: T0

- func: aten::lift_fresh_copy
- func: aten::index_select
namespace: edge
inherits: aten::lift_fresh_copy
inherits: aten::index_select
type_alias:
T0: [Bool, Byte, Char, Double, Float, Int, Long, Short]
T0: [Bool]
T1: [Byte]
T2: [Char]
T3: [Double]
T4: [Float]
T5: [Int]
T6: [Long]
T7: [Short]
type_constraint:
- self: T0
index: T6
__ret_0: T0
- self: T1
index: T6
__ret_0: T1
- self: T2
index: T6
__ret_0: T2
- self: T3
index: T6
__ret_0: T3
- self: T4
index: T6
__ret_0: T4
- self: T5
index: T6
__ret_0: T5
- self: T6
index: T6
__ret_0: T6
- self: T7
index: T6
__ret_0: T7

- func: aten::masked_fill.Scalar
namespace: edge
@@ -245,16 +282,6 @@
mask: T0
__ret_0: T7

- func: aten::minimum
namespace: edge
inherits: aten::minimum
type_alias:
T0: [Bool, Byte, Char, Double, Float, Int, Long, Short]
type_constraint:
- self: T0
other: T0
__ret_0: T0

- func: aten::mm
namespace: edge
inherits: aten::mm
@@ -324,15 +351,6 @@
- self: T0
__ret_0: T0

- func: aten::select_copy.int
namespace: edge
inherits: aten::select_copy.int
type_alias:
T0: [Bool, Byte, Char, Double, Float, Int, Long, Short]
type_constraint:
- self: T0
__ret_0: T0

- func: aten::sigmoid
namespace: edge
inherits: aten::sigmoid
@@ -383,9 +401,25 @@
other: T0
__ret_0: T0

- func: aten::t
- func: aten::sym_numel
namespace: edge
inherits: aten::sym_numel
type_alias:
T0: [Bool, Byte, Char, Double, Float, Int, Long, Short]
type_constraint:
- self: T0

- func: aten::sym_size.int
namespace: edge
inherits: aten::sym_size.int
type_alias:
T0: [Bool, Byte, Char, Double, Float, Int, Long, Short]
type_constraint:
- self: T0

- func: aten::t_copy
namespace: edge
inherits: aten::t
inherits: aten::t_copy
type_alias:
T0: [Bool, Byte, Char, Double, Float, Int, Long, Short]
type_constraint:
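A side note on the index_select rows added above: each constraint row keeps self and __ret_0 on the same dtype (T0 through T7) and pins index to Long (T6). A small eager-mode illustration of that behaviour (an assumed example, not part of the patch):

import torch

self_tensor = torch.randn(4, 3)                 # Float, i.e. T4 above
index = torch.tensor([0, 2], dtype=torch.long)  # index must be Long (T6)
out = torch.index_select(self_tensor, 0, index)
assert out.dtype == self_tensor.dtype           # __ret_0 shares its dtype with self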
2 changes: 2 additions & 0 deletions exir/dialects/edge/yaml_generator.py
@@ -143,6 +143,8 @@ def get_test_gen_key(op_name: str) -> str:
opdb_key = opdb_key[:-5]
elif opdb_key == "sym_size":
opdb_key = "resize_"
elif opdb_key == "sym_numel":
opdb_key = "abs"
elif opdb_key == "convolution":
opdb_key = "conv_transpose2d"
elif opdb_key == "embedding":
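For illustration, a condensed sketch of the remapping that the yaml_generator.py hunk above extends. The dict form and function name are hypothetical; the real get_test_gen_key uses an elif chain and covers more ops. The intent appears to be that edge ops without their own OpDB entry borrow sample inputs from a stand-in op, and sym_numel now borrows from abs.

# Hypothetical condensed form of the mapping shown in the diff.
_OPDB_STAND_INS = {
    "sym_size": "resize_",
    "sym_numel": "abs",  # added in this change
    "convolution": "conv_transpose2d",
    # (the real function handles further ops, e.g. embedding, not shown in the hunk)
}

def get_test_gen_key_sketch(opdb_key: str) -> str:
    return _OPDB_STAND_INS.get(opdb_key, opdb_key)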