
Commit 1fb2d85

[ExecuTorch][to_backend] Enable to_backend API to leverage preprocess_all
We add a new to_backend API which essentially takes in:

```
Method --> ExportedProgram
Method --> Partitioner
```

This new to_backend API will then return:

```
Method --> ExportedProgram
```

in which the ExportedPrograms are lowered using the partitioner. The key difference from the other to_backend implementation is that this implementation leverages `preprocess_all`.

The existing implementation goes through the following steps:

1. Partition the ExportedProgram and return a tagged module; each tag represents a partition, and nodes belonging to a partition carry the tag in their metadata.
2. Loop through every tag (partition):
   1. Create a submodule and a call_submodule node for the partition.
   2. Create an exported program from the submodule.
   3. Lower this exported program to the specified backend_id and generate a LoweredBackendModule.
   4. Replace the call_submodule node with a call_delegate node referencing the generated LoweredBackendModule.
   5. Adjust the original owning program.
3. Return the newly adjusted owning program.

The new implementation reorders the steps so that preprocess_all can be given all the partitioned programs at once (see the usage sketch below):

1. Loop through every method with a specified partitioner.
2. Partition the method's corresponding ExportedProgram with the specified partitioner and return a tagged module; each tag represents a partition, and nodes belonging to a partition carry the tag in their metadata.
3. Loop through every tag (partition).
4. Create a submodule and a call_submodule node for the partition.
5. Create an exported program from the submodule.
6. Store metadata such as owning_graph, is_submodule, delegation_spec, and exported_program in the call_submodule node's metadata.
7. Return the list of all submodule nodes created for each partition.
8. Add a mapping of each method to its list of partitioned submodule nodes.
9. Sort the above mapping by backend type (backend_id).
10. Loop through backend types:
11. Lower all the submodule nodes of each backend type at once with preprocess_all.
12. Loop through the lowered call_submodule nodes and replace them with the results from preprocess_all.
13. Return the mapping of methods to lowered programs.

The new to_backend API is not used anywhere yet; the intent is for it to eventually be used in the EdgeProgramManager's to_backend implementation, which already has and uses method_to_partitioner and method_to_program mappings.

Differential Revision: [D69954542](https://our.internmc.facebook.com/intern/diff/D69954542/)

ghstack-source-id: 267527354
Pull Request resolved: #9811
1 parent 502f59b commit 1fb2d85

5 files changed: +1084 −15 lines
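For context, here is a minimal, hypothetical usage sketch of the new overload referenced in the commit message above. `AddMul` is a stand-in model, and `AddMulPartitionerDemo` is meant to be the demo partitioner from ExecuTorch's backend tests; treat the exact class name and import paths as assumptions rather than part of this commit.

```python
import torch
from executorch.exir import to_edge
from executorch.exir.backend.backend_api import (
    MethodProgramsPartitionerSpec,
    to_backend,
)
# Assumed demo partitioner from ExecuTorch's backend tests; swap in your
# backend's Partitioner. The exact name/path may differ.
from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo


class AddMul(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return (x + y) * y


edge = to_edge(torch.export.export(AddMul(), (torch.randn(2), torch.randn(2))))

# Map each method name to its Edge-dialect program and to the partitioner
# that tags nodes in that program for delegation.
spec = MethodProgramsPartitionerSpec(
    method_to_edge_program={"forward": edge.exported_program()},
    method_to_partitioner={"forward": AddMulPartitionerDemo()},
)

# Returns Dict[str, ExportedProgram]; all tagged partitions are lowered
# through the batched preprocess_all path.
lowered_programs = to_backend(spec)
```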

exir/backend/backend_api.py

Lines changed: 331 additions & 1 deletion
@@ -9,7 +9,8 @@
 import logging
 from contextlib import contextmanager, nullcontext
 from functools import singledispatch
-from typing import Generator, List
+from typing import Generator, List, Dict
+from dataclasses import dataclass
 
 import torch
 
@@ -417,3 +418,332 @@ def to_backend(
         constants=tagged_exported_program.constants,
         verifiers=[tagged_exported_program.verifier],
     )
+
+
+def _create_partitions_in_graph_module(
+    tagged_graph_module: torch.fx.GraphModule,
+    partition_result: PartitionResult,
+    owning_program: ExportedProgram,
+    is_submodule: bool,
+) -> Dict[str, List[torch.fx.Node]]:
+    backend_id_to_submodule_name = {}
+    for tag, delegation_spec in partition_result.partition_tags.items():
+        # Create a partition with the nodes containing this tag. There should
+        # only be one contained submodule per tag.
+        node_list = _get_node_list_with_same_tag(
+            tagged_graph_module, tag, owning_program
+        )
+
+        if len(node_list) == 0:
+            logging.debug(f"Did not find any nodes for tag {tag}")
+            continue
+
+        logging.debug(f"For tag {tag}, found nodes {node_list}")
+        # Tag the nodes that are params as buffers, so we can order the submodule as (Params + Buffers) (User Inputs)
+
+        replace_ctx = (
+            tagged_graph_module._set_replace_hook(
+                owning_program.graph_signature.get_replace_hook()
+            )
+            if not is_submodule
+            else nullcontext()
+        )
+        with replace_ctx:
+            submodule, call_module_node = create_submodule_from_nodes(
+                tagged_graph_module, node_list, tag
+            )
+
+        tagged_graph_module_output_node = [
+            node for node in tagged_graph_module.graph.nodes if node.op == "output"
+        ][0]
+        submodule_output_node = [
+            node for node in submodule.graph.nodes if node.op == "output"
+        ][0]
+        # Copy the output node meta from the original output node, because
+        # create_submodule_from_nodes doesn't cover the meta field.
+        submodule_output_node.meta = tagged_graph_module_output_node.meta
+        logging.debug(f"Partitioned graph module: {tagged_graph_module}")
+        (
+            submodule_program,
+            toplevel_input_specs_to_delete,
+            toplevel_output_specs_to_delete,
+        ) = create_exported_program_from_submodule(
+            submodule,
+            owning_program,
+            tag,
+            call_module_node,
+            is_submodule,
+        )
+        call_module_node.meta["backend_id"] = delegation_spec.backend_id
+        call_module_node.meta["compile_spec"] = delegation_spec.compile_specs
+        call_module_node.meta["submodule_program"] = submodule_program
+        call_module_node.meta["toplevel_input_specs_to_delete"] = toplevel_input_specs_to_delete
+        call_module_node.meta["toplevel_output_specs_to_delete"] = toplevel_output_specs_to_delete
+        call_module_node.meta["is_submodule"] = is_submodule
+
+        if delegation_spec.backend_id not in backend_id_to_submodule_name:
+            backend_id_to_submodule_name[delegation_spec.backend_id] = []
+
+        # The call_module_node created here might not be the same node instance as
+        # the one in the final graph module, because this node might be replaced
+        # in future edits to the graph. As a result, we just keep track of the
+        # node's name, and at the end we search for it in the final graph module.
+        backend_id_to_submodule_name[delegation_spec.backend_id].append(call_module_node.target)
+
+    created_submodule_nodes = {key: [] for key in backend_id_to_submodule_name.keys()}
+    for backend_id, submodule_names in backend_id_to_submodule_name.items():
+        for node in tagged_graph_module.graph.nodes:
+            if node.op == "call_module" and node.target in submodule_names:
+                created_submodule_nodes[backend_id].append(node)
+
+    # Check that the numbers of submodule names and submodule nodes are equal.
+    for backend_id in created_submodule_nodes.keys():
+        assert len(created_submodule_nodes[backend_id]) == len(backend_id_to_submodule_name[backend_id])
+
+    return created_submodule_nodes
+
+
+def _create_partitions(
+    tagged_graph_module: torch.fx.GraphModule,
+    partition_result: PartitionResult,
+    owning_program: ExportedProgram,
+    is_submodule: bool = False,
+) -> Dict[str, List[torch.fx.Node]]:
+    backend_id_to_call_submodules = _create_partitions_in_graph_module(
+        tagged_graph_module, partition_result, owning_program, is_submodule
+    )
+
+    # Recursively partition and lower for control-flow submodules.
+    for _, submod, _ in get_control_flow_submodules(tagged_graph_module):
+        nested_backend_id_to_call_submodules = _create_partitions(
+            submod, partition_result, owning_program, is_submodule=True
+        )
+        for backend_id, nested_submodules in nested_backend_id_to_call_submodules.items():
+            if backend_id not in backend_id_to_call_submodules:
+                backend_id_to_call_submodules[backend_id] = nested_submodules
+            else:
+                backend_id_to_call_submodules[backend_id].extend(nested_submodules)
+
+    return backend_id_to_call_submodules
+
+
+def lower_all_submodules_to_backend(
+    backend_id: str,
+    method_to_submodules_nodes: Dict[str, List[torch.fx.Node]],
+    method_to_tagged_edge_program: Dict[str, ExportedProgram],
+) -> None:
+    """
+    Lower all submodule nodes given in the method_to_submodules_nodes map to backend_id.
+    """
+    # The exported programs created for the submodules are stored in the
+    # call_module nodes' metadata, so we can map method_to_submodules_nodes
+    # directly to the partitioned exported programs.
+    method_to_partitioned_program = {
+        method_name: [node.meta["submodule_program"] for node in call_submodule_nodes]
+        for method_name, call_submodule_nodes in method_to_submodules_nodes.items()
+    }
+    method_to_compile_specs = {
+        method_name: [node.meta["compile_spec"] for node in call_submodule_nodes]
+        for method_name, call_submodule_nodes in method_to_submodules_nodes.items()
+    }
+    backend_found = False
+    for cls in BackendDetails.__subclasses__():
+        if backend_id == cls.__name__:
+            method_to_preprocess_result: Dict[str, List[PreprocessResult]] = cls.preprocess_all(
+                method_to_partitioned_program,
+                method_to_compile_specs,
+            )
+            backend_found = True
+
+    if not backend_found:
+        raise NotImplementedError(f"Backend {backend_id} was not found.")
+
+    for method_name in method_to_preprocess_result.keys():
+        owning_program = method_to_tagged_edge_program[method_name]
+        list_of_preprocess_results = method_to_preprocess_result[method_name]
+        list_of_call_submodule_nodes = method_to_submodules_nodes[method_name]
+        list_of_compile_specs = method_to_compile_specs[method_name]
+        assert len(list_of_preprocess_results) == len(list_of_call_submodule_nodes), (
+            f"Expected {len(list_of_call_submodule_nodes)} preprocessed results for "
+            f"method {method_name} but got {len(list_of_preprocess_results)}"
+        )
+        for preprocess_result, call_submodule_node, compile_spec in zip(
+            list_of_preprocess_results,
+            list_of_call_submodule_nodes,
+            list_of_compile_specs,
+        ):
+            submodule_program = call_submodule_node.meta["submodule_program"]
+            lowered_module = LoweredBackendModule(
+                edge_program=submodule_program,
+                backend_id=backend_id,
+                processed_bytes=preprocess_result.processed_bytes,
+                compile_specs=compile_spec,
+            )
+            owning_graph_module = call_submodule_node.graph.owning_module
+            is_submodule = call_submodule_node.meta["is_submodule"]
+            toplevel_input_specs_to_delete = call_submodule_node.meta["toplevel_input_specs_to_delete"]
+            toplevel_output_specs_to_delete = call_submodule_node.meta["toplevel_output_specs_to_delete"]
+            # call_delegate args should only use user_inputs.
+            call_delegate_args = []
+            # Preserve the input order of user_inputs.
+            for inp_name in submodule_program.graph_signature.user_inputs:
+                for inp_node in call_submodule_node.all_input_nodes:
+                    if inp_node.name == inp_name:
+                        call_delegate_args.append(inp_node)
+                        break
+
+            def generate_debug_handle(ep: ExportedProgram) -> int:
+                """
+                Generate a fresh debug handle for the given ExportedProgram,
+                one greater than the maximum handle already present.
+                """
+                debug_handle = 0
+                for node in ep.graph_module.graph.nodes:
+                    debug_handle = max(debug_handle, node.meta.get("debug_handle", 0))
+                return debug_handle + 1
+
+            # Replace the partitioned submodule with a lowered submodule: add a
+            # call_function node that delegates to the lowered module's "forward".
+            with owning_graph_module.graph.inserting_before(call_submodule_node):
+                lowered_name = get_lowered_module_name(
+                    owning_graph_module, lowered_module
+                )
+                lowered_node = owning_graph_module.graph.get_attr(lowered_name)
+                call_delegate_node = owning_graph_module.graph.call_function(
+                    executorch_call_delegate,
+                    (lowered_node,) + tuple(call_delegate_args),
+                    call_submodule_node.kwargs,
+                )
+                call_delegate_node.meta["debug_handle"] = generate_debug_handle(
+                    owning_program
+                )
+                call_delegate_node.meta["val"] = call_submodule_node.meta["val"]
+                call_submodule_node.replace_all_uses_with(call_delegate_node)
+                owning_graph_module.graph.erase_node(call_submodule_node)
+
+            if is_submodule:
+                assert len(toplevel_input_specs_to_delete) == 0
+                assert len(toplevel_output_specs_to_delete) == 0
+            elif (
+                len(toplevel_input_specs_to_delete) > 0
+                or len(toplevel_output_specs_to_delete) > 0
+            ):
+                _unsafe_adjust_original_program(
+                    owning_program,
+                    call_delegate_node,
+                    toplevel_input_specs_to_delete,
+                    toplevel_output_specs_to_delete,
+                )
+
+
+@dataclass
+class MethodProgramsPartitionerSpec:
+    """
+    Since single dispatch for to_backend requires the first argument to be a
+    valid class, we create this dataclass spec to hold the dictionaries
+    mapping each method name to its corresponding program and partitioner.
+    """
+
+    method_to_edge_program: Dict[str, ExportedProgram]
+    method_to_partitioner: Dict[str, Partitioner]
+
+
+@to_backend.register
+def _(
+    method_edge_program_partitioners: MethodProgramsPartitionerSpec,
+) -> Dict[str, ExportedProgram]:
+    """
+    Add overloaded implementations for to_backend:
+
+    ::
+
+     def to_backend(
+         method_edge_program_partitioners: MethodProgramsPartitionerSpec
+     ) -> Dict[str, ExportedProgram]:
+
+    Returns a dictionary of programs semantically equivalent to the input
+    programs (represented as graph modules in Edge dialect), but with portions
+    of each program targeted for delegation as determined by the partitioner.
+
+    Args:
+        method_edge_program_partitioners: contains two mappings,
+            - method_to_edge_program: mapping of method names to their
+              respective programs in Edge dialect.
+            - method_to_partitioner: mapping of method names to an instance of
+              the partitioner, in charge of tagging portions of the specified
+              program for delegation. A valid partitioner must return a
+              PartitionResult including both the tagged exported program and
+              partition_tags: Dict[str, DelegationSpec], where each key is a
+              tag name; the nodes with the same tag will be fused into one
+              subgraph and delegated to the backend specified in the
+              delegation spec.
+
+    Returns:
+        Dict[str, ExportedProgram]: The input programs, with some portions
+        targeted for delegation.
+    """
+    method_to_edge_program = method_edge_program_partitioners.method_to_edge_program
+    method_to_partitioner = method_edge_program_partitioners.method_to_partitioner
+
+    partitioned_and_lowered_exported_programs = {}
+    backend_id_to_method_submodules_map = {}
+    method_to_tagged_exported_program = {}
+
+    for method_name, partitioner_instance in method_to_partitioner.items():
+        assert (
+            method_name in method_to_edge_program
+        ), f"Edge program for method {method_name} is not provided"
+        edge_program = method_to_edge_program[method_name]
+        edge_program._validate()
+
+        # Use a fake program, with FakeTensors in the state dict, to avoid
+        # copying large constant values. Fall back to deepcopy if no fake mode
+        # is found. TODO(T182910699): Remove this fallback.
+        try:
+            fake_edge_program = get_fake_program(edge_program)
+        except Exception as e:
+            logging.warning(
+                f"Error in get_fake_program for graph {edge_program.graph_module}, fallback to deepcopy: {e}"
+            )
+            fake_edge_program = copy.deepcopy(edge_program)
+        partitioner_result = partitioner_instance(fake_edge_program)
+        tagged_exported_program = partitioner_result.tagged_exported_program
+        method_to_tagged_exported_program[method_name] = tagged_exported_program
+
+        # Check that the partitioner did not modify the original graph.
+        if _ENABLE_VALIDATION:
+            assert is_identical_graph(
+                tagged_exported_program.graph_module,
+                edge_program.graph_module,
+            ), f"The partitioner {partitioner_instance} should not modify the graph module"
+        else:
+            logging.warning("Disabled validating the partitioner.")
+
+        assert (
+            partitioner_result.partition_tags is not None
+        ), f"Partitioner {partitioner_instance} needs a `partition_tags` field containing a mapping of tags to delegate specs"
+
+        update_to_real_program(tagged_exported_program, edge_program)
+
+        for tag, _ in partitioner_result.partition_tags.items():
+            _maybe_duplicate_constant_nodes(tagged_exported_program, tag)
+
+        backend_id_to_call_submodule_nodes = _create_partitions(
+            tagged_exported_program.graph_module,
+            partitioner_result,
+            tagged_exported_program,
+        )
+        for backend_id, call_submodule_nodes in backend_id_to_call_submodule_nodes.items():
+            if backend_id not in backend_id_to_method_submodules_map:
+                backend_id_to_method_submodules_map[backend_id] = {}
+            backend_id_to_method_submodules_map[backend_id][method_name] = call_submodule_nodes
+
+    for backend_id, method_to_submodule_nodes in backend_id_to_method_submodules_map.items():
+        lower_all_submodules_to_backend(
+            backend_id,
+            method_to_submodule_nodes,
+            method_to_tagged_exported_program,
+        )
+
+    for method_name in method_to_edge_program.keys():
+        if method_name in method_to_tagged_exported_program:
+            tagged_exported_program = method_to_tagged_exported_program[method_name]
+            partitioned_and_lowered_exported_programs[method_name] = ExportedProgram(
+                root=tagged_exported_program.graph_module,
+                graph=tagged_exported_program.graph_module.graph,
+                graph_signature=tagged_exported_program.graph_signature,
+                state_dict=tagged_exported_program.state_dict,
+                range_constraints=copy.deepcopy(tagged_exported_program.range_constraints),
+                module_call_graph=copy.deepcopy(tagged_exported_program.module_call_graph),
+                example_inputs=None,
+                constants=tagged_exported_program.constants,
+                verifiers=[tagged_exported_program.verifier],
+            )
+        else:
+            # This edge program wasn't partitioned, so we can return it as is.
+            partitioned_and_lowered_exported_programs[method_name] = method_to_edge_program[method_name]
+
+    return partitioned_and_lowered_exported_programs
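
The loop over `BackendDetails.__subclasses__()` above resolves a backend by class name and hands every partition, across all methods, to its `preprocess_all` in one call. As a rough sketch of the receiving side, under the assumption that the base-class signature matches the call site in `lower_all_submodules_to_backend` and that the import paths are as shown (the class `MyBatchedBackend` and its payload are hypothetical):

```python
from typing import Dict, List

from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch.export import ExportedProgram


class MyBatchedBackend(BackendDetails):
    """Toy backend: sees every partition of every method in one call."""

    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
        # Toy payload; a real backend serializes a compiled artifact here.
        return PreprocessResult(processed_bytes=str(edge_program.graph).encode())

    @classmethod
    def preprocess_all(
        cls,
        method_to_edge_programs: Dict[str, List[ExportedProgram]],
        method_to_compile_specs: Dict[str, List[List[CompileSpec]]],
    ) -> Dict[str, List[PreprocessResult]]:
        # Cross-partition decisions (e.g. shared weight pools) could happen
        # here, since all partitions are visible at once; this sketch just
        # fans out to the per-program preprocess.
        return {
            method: [
                cls.preprocess(program, specs)
                for program, specs in zip(programs, method_to_compile_specs[method])
            ]
            for method, programs in method_to_edge_programs.items()
        }
```

A partition tagged with backend_id "MyBatchedBackend" would then be routed through this class; the batched entry point is what lets a backend make cross-partition decisions that the one-partition-at-a-time preprocess path cannot.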
