|
| 1 | +# This pass exists to propagate input spec metadata down through nested |
| 2 | +# submodules. Specifically, metadata for the type of tensor - USER_INPUT, PARAM, |
| 3 | +# BUFFER, corresponding to torch.export.graph_signature.InputKind. If the tensor |
| 4 | +# is not a direct input in all paths, it's left as None. |
| 5 | +# |
| 6 | +# Metadata is stored in the node meta["input_spec"] with a type of |
| 7 | +# torch.export.graph_signature.InputSpec or None. It corresponds to the output |
| 8 | +# value of the node, and can be a tuple for nodes that return tuples. |
| 9 | +# |
| 10 | +# After this pass runs, it should be present on all nodes, including arbitrarily |
| 11 | +# nested submodules. This may become stale if the graph is mutated, though. |
| 12 | + |
| 13 | +from typing import Any, Sequence |
| 14 | + |
| 15 | +import torch |
| 16 | + |
| 17 | +from torch.export import ExportedProgram |
| 18 | +from torch.export.graph_signature import InputSpec |
| 19 | +from torch.fx import GraphModule, Node |
| 20 | + |
| 21 | +# Key for node.meta dict. |
| 22 | +INPUT_SPEC_KEY = "input_spec" |
| 23 | + |
| 24 | + |
def propagate_input_spec(ep: ExportedProgram) -> ExportedProgram:
    """
    Assign the meta["input_spec"] value for placeholders in the graph, including
    placeholder nodes in control flow submodules.

    Args:
        ep: The exported program to annotate. Its graph modules are mutated
            in place.

    Returns:
        The same ExportedProgram, with placeholder metadata refreshed.
    """
    # Clear any stale input_spec metadata before propagating fresh values.
    # Passes like duplicate_constant_node copy all metadata (including
    # input_spec) to new nodes, leaving stale specs that don't match the
    # updated EP signature.
    _clear_input_spec_recursive(ep.graph_module)

    # Map each top-level placeholder name to its signature spec; nested
    # submodules are handled recursively from here.
    inputs = {s.arg.name: s for s in ep.graph_signature.input_specs}
    _propagate_input_spec_recursive(ep.graph_module, inputs)

    # BUG FIX: the return annotation promises an ExportedProgram, but the
    # original implementation fell off the end and implicitly returned None.
    return ep
| 38 | + |
| 39 | + |
def _clear_input_spec_recursive(gm: GraphModule) -> None:
    """
    Remove any existing input_spec metadata from every placeholder node in
    *gm*, then do the same for each GraphModule child, depth-first.
    """
    placeholders = (n for n in gm.graph.nodes if n.op == "placeholder")
    for placeholder in placeholders:
        placeholder.meta.pop(INPUT_SPEC_KEY, None)

    for _name, child in gm.named_children():
        if isinstance(child, GraphModule):
            _clear_input_spec_recursive(child)
| 47 | + |
| 48 | + |
| 49 | +def _collect_node_arg_specs(args: Sequence[Any]) -> list[InputSpec | None]: |
| 50 | + """ |
| 51 | + Retrieve the input spec for each node arg. |
| 52 | + """ |
| 53 | + return [ |
| 54 | + n.meta.get(INPUT_SPEC_KEY, None) if hasattr(n, "meta") else None for n in args |
| 55 | + ] |
| 56 | + |
| 57 | + |
def _propagate_input_spec_recursive(
    gm: GraphModule, inputs: dict[str, InputSpec] | Sequence[InputSpec]
) -> None:
    """
    Given a dictionary or list of InputSpecs for graph inputs, propagate the
    specs to any nested submodules.
    """
    # Submodules have no graph signature of their own, so when the specs
    # arrive as a sequence, rebuild the placeholder-name -> spec mapping by
    # pairing placeholders (in graph order) with the sequence.
    #
    # This relies on placeholder node order matching graph inputs - but this
    # seems to be an implicit contract that pytorch already uses...
    if not isinstance(inputs, dict):
        placeholder_targets = [
            n.target for n in gm.graph.nodes if n.op == "placeholder"
        ]
        inputs = {
            target: inputs[i] for i, target in enumerate(placeholder_targets)
        }

    # Control-flow higher-order ops each know how their operands map onto
    # their submodule inputs; dispatch per-op.
    handlers = {
        torch.ops.higher_order.cond: _update_cond_meta,
        torch.ops.higher_order.map_impl: _update_map_meta,
        torch.ops.higher_order.scan: _update_scan_meta,
        torch.ops.higher_order.while_loop: _update_while_loop_meta,
    }
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            _update_placeholder_meta(node, inputs)
        else:
            handler = handlers.get(node.target, None)
            if handler is not None:
                handler(node, inputs)
| 89 | + |
| 90 | + |
def _update_placeholder_meta(node: Node, inputs: dict[str, InputSpec]) -> None:
    """
    Set or clear the input_spec metadata on one placeholder node, depending on
    whether *inputs* carries a spec for the node's target name.
    """
    spec = inputs.get(node.target, None)

    if spec is None:
        # No (or explicitly absent) spec: make sure no stale entry remains.
        node.meta.pop(INPUT_SPEC_KEY, None)
    else:
        node.meta[INPUT_SPEC_KEY] = spec
| 98 | + |
| 99 | + |
def _update_cond_meta(node: Node, inputs: dict[str, InputSpec]) -> None:
    """
    Propagate input specs into both branch submodules of a cond node. The
    branch placeholder specs come from the cond node's own operand list;
    *inputs* is unused but kept for a uniform handler signature.
    """
    _, true_branch_ref, false_branch_ref, branch_operands = node.args
    operand_specs = _collect_node_arg_specs(branch_operands)

    # The branches are referenced via get_attr nodes; look the actual
    # submodules up on the owning module.
    owner = node.graph.owning_module
    for branch_ref in (true_branch_ref, false_branch_ref):
        branch = getattr(owner, branch_ref.target)
        _propagate_input_spec_recursive(branch, operand_specs)
| 111 | + |
| 112 | + |
def _update_map_meta(node: Node, inputs: dict[str, InputSpec]) -> None:
    """
    Propagate input specs into the body submodule of a map_impl node. The
    body's placeholders correspond to the mapped args followed by the plain
    operands; *inputs* is unused but kept for a uniform handler signature.
    """
    body_ref, mapped_args, operands = node.args
    spec_list = _collect_node_arg_specs(mapped_args)
    spec_list.extend(_collect_node_arg_specs(operands))

    # Resolve the get_attr reference to the actual body submodule.
    body = getattr(node.graph.owning_module, body_ref.target)
    _propagate_input_spec_recursive(body, spec_list)
| 124 | + |
| 125 | + |
def _update_while_loop_meta(node: Node, inputs: dict[str, InputSpec]) -> None:
    """
    Propagate input specs into the cond and body submodules of a while_loop
    node. Both submodules take the carried inputs followed by the additional
    inputs; *inputs* is unused but kept for a uniform handler signature.
    """
    cond_ref, body_ref, carried_inputs, additional_inputs = node.args
    spec_list = _collect_node_arg_specs(carried_inputs)
    spec_list.extend(_collect_node_arg_specs(additional_inputs))

    # Both function references are get_attr nodes on the owning module.
    owner = node.graph.owning_module
    for fn_ref in (cond_ref, body_ref):
        fn = getattr(owner, fn_ref.target)
        _propagate_input_spec_recursive(fn, spec_list)
| 138 | + |
| 139 | + |
def _update_scan_meta(node: Node, inputs: dict[str, InputSpec]) -> None:
    """
    Propagate input specs into the combine submodule of a scan node. Its
    placeholders correspond to init, then xs, then the additional inputs;
    *inputs* is unused but kept for a uniform handler signature.
    """
    combine_ref, init, xs, additional_inputs = node.args
    spec_list: list[InputSpec | None] = []
    for operand_group in (init, xs, additional_inputs):
        spec_list.extend(_collect_node_arg_specs(operand_group))

    # Resolve the get_attr reference to the actual combine submodule.
    combine_fn = getattr(node.graph.owning_module, combine_ref.target)
    _propagate_input_spec_recursive(combine_fn, spec_list)