From 7a8e8bf193e529faf84c8e5d3b68384446920fb1 Mon Sep 17 00:00:00 2001
From: Jacob Szwejbka
Date: Wed, 9 Apr 2025 15:33:07 -0700
Subject: [PATCH] memory_planning algos take the specs as inputs instead of calculating them themselves (#9952)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/9952

Refactor the algos so that they no longer calculate the active set of nodes themselves. Future work should refactor apply_algo further so that it is not recursive and handles lifespans across control flow more intelligently.

Made the memory planning suite a class so that the algo list is easily configurable without having to make a wrapper function.

Reviewed By: tarun292, skrtskrtfb

Differential Revision: D72600295
---
 backends/cadence/aot/memory_planning.py |  22 +--
 backends/vulkan/vulkan_preprocess.py    |   6 +-
 exir/memory_planning.py                 | 251 ++++++++++++++----------
 exir/passes/memory_planning_pass.py     |  10 +-
 exir/tests/test_memory_planning.py      |  10 +-
 5 files changed, 168 insertions(+), 131 deletions(-)

diff --git a/backends/cadence/aot/memory_planning.py b/backends/cadence/aot/memory_planning.py
index cfe1b9ab9d8..3c6c518f16a 100644
--- a/backends/cadence/aot/memory_planning.py
+++ b/backends/cadence/aot/memory_planning.py
@@ -12,7 +12,7 @@
 import math
 import typing
 from functools import partial
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Set, Tuple
 
 import torch
 from executorch.backends.cadence.aot.memory_constraints import (
@@ -73,11 +73,11 @@ def collect_specs_from_graph_module(
 # the fastest memory available
 # flake8: noqa 'position_based_greedy_with_hierarchy' is too complex (13)
 def position_based_greedy_with_hierarchy(
-    graph_module: torch.fx.GraphModule,
     alignment: int,
+    specs: Set[TensorSpec],
+    graph_module: torch.fx.GraphModule,
     graph_signature: ExportGraphSignature,
-    alloc_graph_input: bool,
-    alloc_graph_output: bool,
+    extra_padding: int = 0,
     *,
     memory_config: MemoryConfig,
     mem_constraints: MemConstraints,
@@ -119,9 +119,7 @@ def memory_available(spec: TensorSpec) -> bool:
 
     # Iterate over all the specs in sorted order
     for spec in sorted(
-        collect_specs_from_graph_module(
-            graph_module, graph_signature, alloc_graph_input, alloc_graph_output
-        ),
+        specs,
         key=lambda spec: spec.allocated_memory,
         reverse=True,
     ):
@@ -167,11 +165,11 @@ def memory_available(spec: TensorSpec) -> bool:
 
 # Greedy tensor placement with the heuristics from arxiv.org/pdf/2001.03288.pdf
 def greedy_by_size_for_offset_calculation_with_hierarchy(
-    graph_module: torch.fx.GraphModule,
     alignment: int,
+    specs: Set[TensorSpec],
+    graph_module: torch.fx.GraphModule,
     graph_signature: ExportGraphSignature,
-    alloc_graph_input: bool,
-    alloc_graph_output: bool,
+    extra_padding: int = 0,
     *,
     memory_config: MemoryConfig,
     mem_constraints: MemConstraints,
@@ -199,9 +197,7 @@ def greedy_by_size_for_offset_calculation_with_hierarchy(
 
     # Iterate over all the specs in sorted order
     for spec in sorted(
-        collect_specs_from_graph_module(
-            graph_module, graph_signature, alloc_graph_input, alloc_graph_output
-        ),
+        specs,
         key=lambda spec: spec.allocated_memory,
         reverse=True,
     ):
diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py
index 1c1c51bb58a..188311e5f2c 100644
--- a/backends/vulkan/vulkan_preprocess.py
+++ b/backends/vulkan/vulkan_preprocess.py
@@ -47,7 +47,7 @@
 )
 from executorch.exir.backend.utils import DelegateMappingBuilder
 
-from executorch.exir.memory_planning import greedy, memory_planning_algorithm_suite
+from 
executorch.exir.memory_planning import greedy, MemoryPlanningAlgorithmSuite from executorch.exir.pass_base import ExportPass, PassBase from executorch.exir.passes import MemoryPlanningPass, SpecPropPass @@ -199,8 +199,8 @@ def preprocess( # noqa: C901 # Finally, apply dynamic shape passes and memory planning pass. These passes # must be applied only when the graph structure is finalized. greedy_memory_planning = partial(greedy, allow_overlapping_allocations=False) - mem_planning_suite = partial( - memory_planning_algorithm_suite, algo_list=[greedy_memory_planning] + mem_planning_suite = MemoryPlanningAlgorithmSuite( + algo_list=[greedy_memory_planning] ) program = apply_passes( program, diff --git a/exir/memory_planning.py b/exir/memory_planning.py index 3f45276c9e2..17640a9f7aa 100644 --- a/exir/memory_planning.py +++ b/exir/memory_planning.py @@ -731,53 +731,43 @@ def _contains_xnnpack_delegate(graph_module: torch.fx.GraphModule) -> bool: def greedy( - graph_module: torch.fx.GraphModule, alignment: int, - graph_signature: Optional[ExportGraphSignature] = None, - alloc_graph_input: bool = True, - alloc_graph_output: bool = True, + specs: Set[TensorSpec], + graph_module: torch.fx.GraphModule, + graph_signature: ExportGraphSignature, + extra_padding: int = 0, + *, allow_overlapping_allocations: bool = True, ) -> MemoryAlgoResult: r"""Greedy algorithm to allocate memory for tensors in the graph. - alloc_graph_input: If set to true, the algorithm will allocate memory for graph input. - alloc_graph_output: If set to true, the algorithm will allocate memory for graph output. - allow_overlapping_allocations: If set to true, allows for allocations that overlap - in their lifetime but are at different offsets in the storage. By default true. - This flag is added to allow for Vulkan to use MemoryPlanningPass with overlapping - allocations disabled + + Args: + alignment: Memory alignment requirement + specs: Set of TensorSpec objects with updated lifetimes + graph_module: Graph module + graph_signature: Graph signature + extra_padding: Additional padding to add to each memory buffer (in bytes) + allow_overlapping_allocations: If set to true, allows for allocations that overlap + in their lifetime but are at different offsets in the storage. By default true. + This flag is added to allow for Vulkan to use MemoryPlanningPass with overlapping + allocations disabled + + Returns: + MemoryAlgoResult containing the allocation decisions """ greedy_result = MemoryAlgoResult({}, []) - # padding allocation with 64 bytes. - # this requirement is really for XNNPACK backend which can read tensors - # beyond the end of the tensor. This is done for performance - # optimizations in XNNPACK. - # While accounting for backend specific requirement is not the right choice - # in backend agnostic memory planning, we do it here as it seems most appropriate. - # Right now this applies to greedy only so any other - # algorithm that plans memory for XNNPACK backend will - # not have this. - extra_padded_bytes = 0 - if _contains_xnnpack_delegate(graph_module): - extra_padded_bytes = 64 spec2obj = {} shared_objects = defaultdict(list) - # Don't do assertion in collect_specs_from_nodes if we have already encountered - # and ignored some to_out_variant errors. - do_assertion = not getattr(graph_module, "encounter_to_out_var_failure", False) + # For each tensor, pick the available shared object with closest size to # the tensor. If there are no available shared object left, create a new # one. 
     import bisect
     sorted_specs = []
-    for spec in collect_specs_from_nodes(
-        graph_module.graph.nodes,
-        graph_signature,
-        do_assertion=do_assertion,
-        ignore_graph_input=not alloc_graph_input,
-        ignore_graph_output=not alloc_graph_output,
-    ):
+    for spec in specs:
         bisect.insort(sorted_specs, spec, key=lambda x: x.allocated_memory)
+
     sorted_specs.reverse()
 
     for spec in sorted_specs:
@@ -806,15 +796,13 @@ def greedy(
     for mem_id in shared_objects:
         input_total_size = 0
         if bufsizes := getattr(graph_module, "input_mem_buffer_sizes", None):
-            # pyre-fixme[6]: For 1st argument expected
-            #  `pyre_extensions.ReadOnly[Sized]` but got `Union[Tensor, Module]`.
+            assert isinstance(bufsizes, list)
             if len(bufsizes) > mem_id:
-                # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.Ten...
                 input_total_size = bufsizes[mem_id]
         total_sizes[mem_id] = materialize_buffer(
             shared_objects[mem_id], input_total_size
         )
-        total_sizes[mem_id] += extra_padded_bytes
+        total_sizes[mem_id] += extra_padding
 
     # Since we now know the number of shared objects we need and the size of
     # each shared object, we can assign offset in the memory buffer for each
@@ -838,72 +826,107 @@ def greedy(
     return greedy_result
 
 
-def memory_planning_algorithm_suite(
-    graph_module: torch.fx.GraphModule,
-    alignment: int,
-    graph_signature: Optional[ExportGraphSignature] = None,
-    alloc_graph_input: bool = True,
-    alloc_graph_output: bool = True,
-    allow_overlapping_allocations: bool = True,
-    algo_list: Optional[List[Callable[..., MemoryAlgoResult]]] = None,
-) -> List[int]:
-    r"""
-    Memory planning algorithm suite that runs a list of memory planning algorithms
-    and returns the result of the algorithm that minimizes the total memory usage.
-    """
-    if algo_list is None:
-        algo_list = [greedy]
-    mem_algo_results = {}
-    for algo in algo_list:
-        if isinstance(algo, functools.partial):
-            name = algo.func.__name__
-        else:
-            name = getattr(algo, "__name__", None)
-        # Run this memory planning algorithm and store the result in mem_algo_results
-        # with the name of the algorithm as the key.
-        mem_algo_results[name] = algo(
-            graph_module,
-            alignment,
-            graph_signature,
-            alloc_graph_input,
-            alloc_graph_output,
-        )
+class MemoryPlanningAlgorithmSuite:
+    def __init__(
+        self,
+        algo_list: Optional[List[Callable[..., MemoryAlgoResult]]] = None,
+    ) -> None:
+        if algo_list is None:
+            algo_list = [greedy]
+        self.algo_list: List[Callable[..., MemoryAlgoResult]] = algo_list
 
-    # All the algorithms should have the same number of buffers allocated.
-    assert (
-        len(
-            {
-                len(mem_algo_result.bufsizes)
-                for mem_algo_result in mem_algo_results.values()
-            }
+    def __call__(
+        self,
+        alignment: int,
+        specs: Set[TensorSpec],
+        graph_module: torch.fx.GraphModule,
+        graph_signature: ExportGraphSignature,
+        extra_padding: int,
+    ) -> List[int]:
+        r"""
+        Memory planning algorithm suite that runs a list of memory planning algorithms
+        and returns the result of the algorithm that minimizes the total memory usage.
+
+        Args:
+            alignment: Memory alignment requirement
+            specs: Set of TensorSpec objects with updated lifetimes
+            graph_module: The graph module to allocate memory for
+            graph_signature: Graph signature
+            extra_padding: Additional padding to add to each memory buffer (in bytes)
+ + Returns: + List of buffer sizes for each memory hierarchy + """ + + mem_algo_results = {} + for algo in self.algo_list: + if isinstance(algo, functools.partial): + name = algo.func.__name__ + else: + name = getattr(algo, "__name__", None) + + mem_algo_results[name] = algo( + alignment, + specs, + graph_module, + graph_signature, + extra_padding, + ) + + # All the algorithms should have the same number of buffers allocated. + assert ( + len( + { + len(mem_algo_result.bufsizes) + for mem_algo_result in mem_algo_results.values() + } + ) + == 1 + ), "Different memory planning algorithms should have the same number of buffers allocated." + + # Find the algorithm that minimizes the total memory usage. + best_algo = min( + mem_algo_results, key=lambda k: sum(mem_algo_results[k].bufsizes) ) - == 1 - ), "Different memory planning algorithms should have the same number of buffers allocated." - - # Find the algorithm that minimizes the total memory usage. - best_algo = min(mem_algo_results, key=lambda k: sum(mem_algo_results[k].bufsizes)) - logging.debug(f"Best memory planning algo for this model is {best_algo}") - bufsizes = mem_algo_results[best_algo].bufsizes - - # Update the mem_id and mem_offset for each spec in the graph module based on the - # values provided by the best memory planning algorithm. - for spec in mem_algo_results[best_algo].spec_dict: - spec_alloc_result = mem_algo_results[best_algo].spec_dict[spec] - spec.mem_id = spec_alloc_result.mem_id - spec.mem_offset = spec_alloc_result.mem_offset - spec.mem_obj_id = spec_alloc_result.mem_obj_id + logging.debug(f"Best memory planning algo for this model is {best_algo}") + bufsizes = mem_algo_results[best_algo].bufsizes - return bufsizes + # Update the mem_id and mem_offset for each spec in the graph module based on the + # values provided by the best memory planning algorithm. + for spec in mem_algo_results[best_algo].spec_dict: + spec_alloc_result = mem_algo_results[best_algo].spec_dict[spec] + spec.mem_id = spec_alloc_result.mem_id + spec.mem_offset = spec_alloc_result.mem_offset + spec.mem_obj_id = spec_alloc_result.mem_obj_id + + return bufsizes def naive( - graph_module: torch.fx.GraphModule, alignment: int, - graph_signature: Optional[ExportGraphSignature] = None, - alloc_graph_input: bool = True, - alloc_graph_output: bool = True, + specs: Set[TensorSpec], + graph_module: torch.fx.GraphModule, + graph_signature: ExportGraphSignature, + extra_padding: int, ) -> MemoryAlgoResult: + """Naive algorithm to allocate memory for tensors in the graph. + This algorithm simply allocates memory for each tensor sequentially without reusing memory. + + Args: + alignment: Memory alignment requirement + specs: Set of TensorSpec objects with updated lifetimes + graph_module: Graph module + graph_signature: Graph signature + extra_padding: Additional padding to add to each memory buffer (in bytes) + + Returns: + MemoryAlgoResult containing the allocation decisions + """ naive_result = MemoryAlgoResult({}, []) # allocate 'allocated' bytes from buffer with id mem_id. 
@@ -918,14 +941,9 @@ def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int: bufsizes = getattr(graph_module, "input_mem_buffer_sizes", None) if bufsizes is None: bufsizes = [0, 0] - bufsizes = typing.cast(List[int], bufsizes) - for spec in collect_specs_from_nodes( - graph_module.graph.nodes, - graph_signature, - ignore_graph_input=not alloc_graph_input, - ignore_graph_output=not alloc_graph_output, - ): + + for spec in specs: spec_alloc_result = naive_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0)) # assume a single memory layer which has mem_id 1 if spec.mem_id is None: @@ -1027,7 +1045,7 @@ def insert_calls_to_free( def apply_algo( algo: Callable[ - [torch.fx.GraphModule, int, Optional[ExportGraphSignature], bool, bool], + ..., List[int], ], graph_module: torch.fx.GraphModule, @@ -1048,11 +1066,35 @@ def apply_algo( TODO: make these optimizations once we have some baseline working. """ - specs = update_all_tensors_lifetime(graph_module, graph_signature) + # Extract the nodes and their lifespans from the graph_module + # Difficult to just filter the list of specs returned by this due to + # how we flag trainable weights. + _ = update_all_tensors_lifetime(graph_module, graph_signature) + + # Filter specs based on alloc_graph_input and alloc_graph_output + specs = collect_specs_from_nodes( + graph_module.graph.nodes, + graph_signature, + do_assertion=False, + ignore_graph_input=not alloc_graph_input, + ignore_graph_output=not alloc_graph_output, + ) + + # Get extra padding for XNNPACK if needed + extra_padding = 0 + if _contains_xnnpack_delegate(graph_module): + extra_padding = 64 + + # Pass the filtered specs to the algorithm bufsizes: List[int] = algo( - graph_module, alignment, graph_signature, alloc_graph_input, alloc_graph_output + alignment, + specs, + graph_module, + graph_signature, + extra_padding, ) - insert_calls_to_free(graph_module, specs) + + insert_calls_to_free(graph_module, set(specs)) def handle_submodule( submodule_nd: torch.fx.Node, alloc_graph_input: bool = False @@ -1063,6 +1105,7 @@ def handle_submodule( # memory planning for submodule need to be aware of the amount of # buffer already allocated. submodule.input_mem_buffer_sizes = bufsizes + bufsizes = apply_algo( algo, submodule, diff --git a/exir/passes/memory_planning_pass.py b/exir/passes/memory_planning_pass.py index f4881e7ab71..f4b3ad8a8a7 100644 --- a/exir/passes/memory_planning_pass.py +++ b/exir/passes/memory_planning_pass.py @@ -17,7 +17,7 @@ _is_out_var_node, apply_algo, get_node_tensor_specs, - memory_planning_algorithm_suite, + MemoryPlanningAlgorithmSuite, Verifier, ) from executorch.exir.operator.convert import get_out_args_from_opoverload @@ -40,9 +40,7 @@ def _callable_name(any_callable: Callable[..., Any]) -> str: class MemoryPlanningPass(PassBase): def __init__( self, - memory_planning_algo: Callable[ - ..., List[int] - ] = memory_planning_algorithm_suite, + memory_planning_algo: Optional[Callable[..., List[int]]] = None, allow_lifetime_and_storage_overlap: bool = False, alloc_graph_input: bool = True, alloc_graph_output: bool = True, @@ -54,6 +52,8 @@ def __init__( the graph input/output. The default behavior is the algorithm will allocate memory for both graph input and output. 
""" + if memory_planning_algo is None: + memory_planning_algo = MemoryPlanningAlgorithmSuite() self.memory_planning_algo = memory_planning_algo self.allow_lifetime_and_storage_overlap = allow_lifetime_and_storage_overlap self.alloc_graph_input = alloc_graph_input @@ -125,7 +125,7 @@ def run( # passes/stages is quite natural and avoid yet another 'context' data structure # to do the job. _ = apply_algo( - self.memory_planning_algo, + self.memory_planning_algo, # pyre-ignore[6] graph_module, self.alignment, graph_signature, diff --git a/exir/tests/test_memory_planning.py b/exir/tests/test_memory_planning.py index 8df0cfed0bf..52986aaa04c 100644 --- a/exir/tests/test_memory_planning.py +++ b/exir/tests/test_memory_planning.py @@ -8,7 +8,6 @@ import itertools import unittest -from functools import partial from typing import Any, Callable, List, Optional, Tuple, Type import executorch.exir as exir @@ -20,8 +19,8 @@ filter_nodes, get_node_tensor_specs, greedy, - memory_planning_algorithm_suite, MemoryAlgoResult, + MemoryPlanningAlgorithmSuite, naive, Verifier, ) @@ -269,7 +268,7 @@ def wrapper(self: "TestMemoryPlanning") -> None: .exported_program() .graph_module ) - mem_algo = partial(memory_planning_algorithm_suite, algo_list=[algo]) + mem_algo = MemoryPlanningAlgorithmSuite(algo_list=[algo]) graph_module = PassManager( passes=[ SpecPropPass(), @@ -497,7 +496,6 @@ def quantize(self, eager_model: nn.Module) -> nn.Module: ) return quantized_model - # pyre-ignore @parameterized.expand( [ ( @@ -514,7 +512,7 @@ def quantize(self, eager_model: nn.Module) -> nn.Module: ) def test_multiple_pools( self, - algo: Callable[..., List[int]], + algo: Callable[..., MemoryAlgoResult], expected_allocs: List[Tuple[int, int]], expected_bufsizes: List[int], ) -> None: @@ -522,7 +520,7 @@ def test_multiple_pools( export(MultiplePoolsToyModel(), (torch.ones(1),), strict=True) ) - mem_algo = partial(memory_planning_algorithm_suite, algo_list=[algo]) + mem_algo = MemoryPlanningAlgorithmSuite(algo_list=[algo]) edge_program.to_executorch( exir.ExecutorchBackendConfig( memory_planning_pass=CustomPoolMemoryPlanningPass(