Skip to content

Commit bfadd0c

Browse files
committed
[executorch] Add propagate_input_spec pass and while_loop HOP support
Add a new pass named `propagate_input_spec`, which recursively assigns meta["input_spec"] on placeholder nodes, including in nested control flow submodules. Placeholders that don't correspond to a top-level input are not assigned this meta key. Also, add `while_loop` to `get_control_flow_submodules`. Differential Revision: [D95876986](https://our.internmc.facebook.com/intern/diff/D95876986/) [ghstack-poisoned]
1 parent 4601e90 commit bfadd0c

File tree

7 files changed

+423
-1
lines changed

7 files changed

+423
-1
lines changed

exir/graph_module.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def get_control_flow_submodules(
7878
) -> List[Tuple[str, torch.fx.GraphModule, torch.fx.Node]]:
7979
"""
8080
Returns a list of submodules used for control flow operations
81-
(torch.ops.higher_order.cond/map/scan) that are in the given toplevel graph (does not look
81+
(torch.ops.higher_order.cond/map/scan/while_loop) that are in the given toplevel graph (does not look
8282
into submodules). Specifically, the returned value is a list containing
8383
tuples of (name of the submodule that's stored in the graph module, the
8484
submodule itself, and the fx node that uses this submodule).
@@ -89,6 +89,7 @@ def get_control_flow_submodules(
8989
torch.ops.higher_order.cond: [1, 2],
9090
torch.ops.higher_order.map_impl: [0],
9191
torch.ops.higher_order.scan: [0], # combine_fn is at arg index 0
92+
torch.ops.higher_order.while_loop: [0, 1],
9293
},
9394
)
9495

exir/passes/BUCK

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,16 @@ fbcode_target(_kind = runtime.python_library,
407407
],
408408
)
409409

410+
# Standalone library for the propagate_input_spec pass, which assigns
# meta["input_spec"] on placeholder nodes, including inside nested
# control-flow submodules.
fbcode_target(_kind = runtime.python_library,
    name = "propagate_input_spec",
    srcs = [
        "propagate_input_spec.py",
    ],
    deps = [
        "//caffe2:torch",
    ],
)
419+
410420
fbcode_target(_kind = runtime.python_library,
411421
name = "remove_unused_parameters_pass",
412422
srcs = [
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
# This pass exists to propagate input spec metadata down through nested
2+
# submodules. Specifically, metadata for the type of tensor - USER_INPUT, PARAM,
3+
# BUFFER, corresponding to torch.export.graph_signature.InputKind. If the tensor
4+
# is not a direct input in all paths, it's left as None.
5+
#
6+
# Metadata is stored in the node meta["input_spec"] with a type of
7+
# torch.export.graph_signature.InputSpec or None. It corresponds to the output
8+
# value of the node, and can be a tuple for nodes that return tuples.
9+
#
10+
# After this pass runs, it should be present on all nodes, including arbitrarily
11+
# nested submodules. This may become stale if the graph is mutated, though.
12+
13+
from typing import Any, Sequence
14+
15+
import torch
16+
17+
from torch.export import ExportedProgram
18+
from torch.export.graph_signature import InputSpec
19+
from torch.fx import GraphModule, Node
20+
21+
# Key under which the InputSpec (or None) is stored in node.meta.
INPUT_SPEC_KEY = "input_spec"
25+
def propagate_input_spec(ep: ExportedProgram) -> ExportedProgram:
    """
    Assign the meta["input_spec"] value for placeholders in the graph, including
    placeholder nodes in control flow submodules.

    Args:
        ep: The exported program whose graph module (and nested control-flow
            submodules) will be annotated in place.

    Returns:
        The same ``ExportedProgram`` instance, annotated in place.
    """
    # Clear any stale input_spec metadata before propagating fresh values.
    # Passes like duplicate_constant_node copy all metadata (including
    # input_spec) to new nodes, leaving stale specs that don't match the
    # updated EP signature.
    _clear_input_spec_recursive(ep.graph_module)

    # Top-level placeholders are matched to specs by argument name; nested
    # submodules are matched positionally inside the recursive helper.
    inputs = {s.arg.name: s for s in ep.graph_signature.input_specs}
    _propagate_input_spec_recursive(ep.graph_module, inputs)

    # Fix: the original fell off the end and returned None despite the
    # ``-> ExportedProgram`` annotation; return the (mutated) program so the
    # signature is honored and the call can be chained.
    return ep
40+
def _clear_input_spec_recursive(gm: GraphModule) -> None:
    """
    Remove any previously-assigned meta["input_spec"] entry from every
    placeholder node in ``gm`` and, recursively, in all of its GraphModule
    children.
    """
    placeholders = (n for n in gm.graph.nodes if n.op == "placeholder")
    for placeholder in placeholders:
        placeholder.meta.pop(INPUT_SPEC_KEY, None)

    # Descend into nested submodules; control-flow branches live here.
    for child in gm.children():
        if isinstance(child, GraphModule):
            _clear_input_spec_recursive(child)
49+
def _collect_node_arg_specs(args: Sequence[Any]) -> list[InputSpec | None]:
    """
    Look up the previously-assigned input spec for each node arg.

    Args without a ``meta`` dict (non-node values) and nodes that were never
    assigned a spec both map to None.
    """
    specs: list[InputSpec | None] = []
    for arg in args:
        if hasattr(arg, "meta"):
            specs.append(arg.meta.get(INPUT_SPEC_KEY, None))
        else:
            specs.append(None)
    return specs
58+
def _propagate_input_spec_recursive(
    gm: GraphModule, inputs: dict[str, InputSpec] | Sequence[InputSpec]
) -> None:
    """
    Given a dictionary or list of InputSpecs for graph inputs, propagate the
    specs to any nested submodules.
    """
    # Submodules don't carry graph signatures, so when given a positional
    # sequence of specs we rebuild the placeholder-name -> spec mapping from
    # placeholder node order. This relies on placeholder order matching the
    # graph inputs - an implicit contract pytorch already depends on.
    if not isinstance(inputs, dict):
        by_name: dict[str, InputSpec] = {}
        for ph in gm.graph.nodes:
            if ph.op == "placeholder":
                by_name[ph.target] = inputs[len(by_name)]
        inputs = by_name

    # Control-flow HOPs dispatch to their dedicated handlers; each recurses
    # into the relevant submodule(s).
    hop_handlers = {
        torch.ops.higher_order.cond: _update_cond_meta,
        torch.ops.higher_order.map_impl: _update_map_meta,
        torch.ops.higher_order.scan: _update_scan_meta,
        torch.ops.higher_order.while_loop: _update_while_loop_meta,
    }

    for node in gm.graph.nodes:
        if node.op == "placeholder":
            _update_placeholder_meta(node, inputs)
        else:
            handler = hop_handlers.get(node.target)
            if handler is not None:
                handler(node, inputs)
91+
def _update_placeholder_meta(node: Node, inputs: dict[str, InputSpec]) -> None:
    """
    Set or clear meta["input_spec"] on a single placeholder node, depending on
    whether a spec exists for its target name.
    """
    spec = inputs.get(node.target)
    if spec is None:
        # No top-level input corresponds to this placeholder; ensure no stale
        # spec is left behind.
        node.meta.pop(INPUT_SPEC_KEY, None)
    else:
        node.meta[INPUT_SPEC_KEY] = spec
100+
def _update_cond_meta(node: Node, inputs: dict[str, InputSpec]) -> None:
    """
    Propagate input specs into both branches of a ``cond`` HOP. The branches
    share the same operand list, so both receive identical specs.
    """
    _, true_branch_attr, false_branch_attr, branch_operands = node.args
    branch_specs = _collect_node_arg_specs(branch_operands)

    # The branch args are get_attr nodes; resolve them to the submodules they
    # name on the owning module before recursing (true branch first).
    owner = node.graph.owning_module
    branches = [
        getattr(owner, attr.target)
        for attr in (true_branch_attr, false_branch_attr)
    ]
    for branch in branches:
        _propagate_input_spec_recursive(branch, branch_specs)
113+
def _update_map_meta(node: Node, inputs: dict[str, InputSpec]) -> None:
    """
    Propagate input specs into the body submodule of a ``map_impl`` HOP. The
    body's placeholders line up with the mapped args followed by the operands.
    """
    fn_attr, mapped_args, operands = node.args
    body_specs = _collect_node_arg_specs(mapped_args) + _collect_node_arg_specs(
        operands
    )

    # Resolve the get_attr node naming the body submodule.
    body = getattr(node.graph.owning_module, fn_attr.target)
    _propagate_input_spec_recursive(body, body_specs)
126+
def _update_while_loop_meta(node: Node, inputs: dict[str, InputSpec]) -> None:
    """
    Propagate input specs into the cond_fn and body_fn submodules of a
    ``while_loop`` HOP. Both take the carried inputs followed by the
    additional inputs, so they share one spec list.
    """
    cond_attr, body_attr, carried_inputs, additional_inputs = node.args
    loop_specs = _collect_node_arg_specs(carried_inputs) + _collect_node_arg_specs(
        additional_inputs
    )

    # Resolve the get_attr nodes to the submodules they name (cond_fn first).
    owner = node.graph.owning_module
    submodules = [getattr(owner, attr.target) for attr in (cond_attr, body_attr)]
    for submodule in submodules:
        _propagate_input_spec_recursive(submodule, loop_specs)
140+
def _update_scan_meta(node: Node, inputs: dict[str, InputSpec]) -> None:
    """
    Propagate input specs into the combine_fn submodule of a ``scan`` HOP.
    Its placeholders line up with init, then xs, then the additional inputs.
    """
    combine_attr, init, xs, additional_inputs = node.args
    combine_specs = (
        _collect_node_arg_specs(init)
        + _collect_node_arg_specs(xs)
        + _collect_node_arg_specs(additional_inputs)
    )

    # Resolve the get_attr node naming the combine_fn submodule.
    combine_fn = getattr(node.graph.owning_module, combine_attr.target)
    _propagate_input_spec_recursive(combine_fn, combine_specs)

exir/program/BUCK

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ fbcode_target(_kind = runtime.python_library,
4040
"//executorch/exir/passes:insert_write_back_for_buffers_pass",
4141
"//executorch/exir/passes:lib",
4242
"//executorch/exir/passes:normalize_view_copy_base_pass",
43+
"//executorch/exir/passes:propagate_input_spec",
4344
"//executorch/exir/passes:remove_graph_asserts_pass",
4445
"//executorch/exir/passes:remove_mixed_type_operators",
4546
"//executorch/exir/passes:replace_aten_with_edge_pass",

exir/program/_program.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
from executorch.exir.passes.normalize_view_copy_base_pass import (
6060
NormalizeViewCopyBasePass,
6161
)
62+
from executorch.exir.passes.propagate_input_spec import propagate_input_spec
6263
from executorch.exir.passes.quant_fusion_pass import quant_fusion_and_const_prop_pass
6364
from executorch.exir.passes.reinplace import reinplace_pass
6465
from executorch.exir.passes.remove_graph_asserts_pass import (
@@ -912,6 +913,9 @@ def _generate_edge_program(
912913
],
913914
)
914915

916+
# Recursively tag placeholder nodes in submodules with input specs.
917+
propagate_input_spec(edge_program)
918+
915919
# Lift the tensor constants created in ScalarToTensorPass
916920
edge_program = lift_constant_tensor_pass(edge_program)
917921

@@ -1232,6 +1236,11 @@ def _gen_edge_manager_for_partitioners(
12321236
# First pass of decompositions with this partitioner's preserved ops
12331237
program = program.run_decompositions(table)
12341238

1239+
# Propagate input specs so that check_constraints
1240+
# can identify parameter nodes inside control flow
1241+
# submodules (e.g. cond/map/scan branches).
1242+
propagate_input_spec(program)
1243+
12351244
# Filter ops using EDGE_DO_NOT_DECOMP
12361245
temp_partitioner_dict = {name: [curr_partitioner]}
12371246
preserved_ops = (

exir/tests/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)