
Commit c7cc64d

JacobSzwejbka authored and facebook-github-bot committed
memory_planning algos take the specs as inputs instead of calculating them themselves (#9952)
Summary: Refactor the algos so they no longer calculate the active set of nodes themselves. Future work should refactor apply_algo further so it is not recursive and actually handles lifespans across control flow intelligently. Made the memory planning suite a class so that the algo list is easily configurable without having to write a wrapper function.

Reviewed By: skrtskrtfb

Differential Revision: D72600295
1 parent 3a940da commit c7cc64d
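For context, a minimal, self-contained sketch of the calling convention this commit moves to: the caller collects the specs once and hands the same set to every algorithm, instead of each algorithm recomputing it. Every name here (`Spec`, `collect_specs`, `greedy_algo`, `run_suite`) is an illustrative stand-in, not the actual ExecuTorch API:

```python
from typing import Callable, FrozenSet, List

# Stand-in for TensorSpec; the real type carries shape, dtype, lifetime, etc.
Spec = str
PlanningAlgo = Callable[[int, FrozenSet[Spec]], List[str]]


def collect_specs() -> FrozenSet[Spec]:
    # In the real code this walks the graph module once to find live tensors.
    return frozenset({"weight", "activation", "output"})


def greedy_algo(alignment: int, specs: FrozenSet[Spec]) -> List[str]:
    # The specs arrive as an input; the algorithm no longer derives them itself.
    return [f"{name}@align{alignment}" for name in sorted(specs)]


def run_suite(algos: List[PlanningAlgo], alignment: int) -> List[List[str]]:
    specs = collect_specs()  # computed once, shared by every algorithm
    return [algo(alignment, specs) for algo in algos]


print(run_suite([greedy_algo], alignment=16))
# [['activation@align16', 'output@align16', 'weight@align16']]
```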

File tree

5 files changed (+162, -131 lines changed)

backends/cadence/aot/memory_planning.py (+9, -13)
```diff
@@ -12,7 +12,7 @@
 import math
 import typing
 from functools import partial
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Set, Tuple
 
 import torch
 from executorch.backends.cadence.aot.memory_constraints import (
@@ -73,11 +73,11 @@ def collect_specs_from_graph_module(
 # the fastest memory available
 # flake8: noqa 'position_based_greedy_with_hierarchy' is too complex (13)
 def position_based_greedy_with_hierarchy(
-    graph_module: torch.fx.GraphModule,
     alignment: int,
+    specs: Set[TensorSpec],
+    graph_module: torch.fx.GraphModule,
     graph_signature: ExportGraphSignature,
-    alloc_graph_input: bool,
-    alloc_graph_output: bool,
+    extra_padding: int = 0,
     *,
     memory_config: MemoryConfig,
     mem_constraints: MemConstraints,
@@ -119,9 +119,7 @@ def memory_available(spec: TensorSpec) -> bool:
 
     # Iterate over all the specs in sorted order
     for spec in sorted(
-        collect_specs_from_graph_module(
-            graph_module, graph_signature, alloc_graph_input, alloc_graph_output
-        ),
+        specs,
         key=lambda spec: spec.allocated_memory,
         reverse=True,
     ):
@@ -167,11 +165,11 @@
 
 # Greedy tensor placement with the heuristics from arxiv.org/pdf/2001.03288.pdf
 def greedy_by_size_for_offset_calculation_with_hierarchy(
-    graph_module: torch.fx.GraphModule,
     alignment: int,
+    specs: Set[TensorSpec],
+    graph_module: torch.fx.GraphModule,
     graph_signature: ExportGraphSignature,
-    alloc_graph_input: bool,
-    alloc_graph_output: bool,
+    extra_padding: int = 0,
     *,
     memory_config: MemoryConfig,
     mem_constraints: MemConstraints,
@@ -199,9 +197,7 @@ def greedy_by_size_for_offset_calculation_with_hierarchy(
 
     # Iterate over all the specs in sorted order
     for spec in sorted(
-        collect_specs_from_graph_module(
-            graph_module, graph_signature, alloc_graph_input, alloc_graph_output
-        ),
+        specs,
        key=lambda spec: spec.allocated_memory,
        reverse=True,
    ):
```
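Both `for spec in sorted(...)` hunks above change only where the specs come from; the largest-first iteration itself is unchanged. A runnable toy of that iteration pattern, using a stand-in `TensorSpec` dataclass rather than the real ExecuTorch type:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TensorSpec:  # stand-in; only the field the sort key needs
    name: str
    allocated_memory: int


# The set is now built once by the caller and passed in as `specs`.
specs = {
    TensorSpec("a", 1024),
    TensorSpec("b", 4096),
    TensorSpec("c", 256),
}

# Same loop shape as the diff: visit specs largest-first, so the biggest
# tensors are placed while the most memory is still unallocated.
for spec in sorted(specs, key=lambda spec: spec.allocated_memory, reverse=True):
    print(spec.name, spec.allocated_memory)
# b 4096, then a 1024, then c 256
```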

backends/vulkan/vulkan_preprocess.py (+2, -4)
```diff
@@ -47,7 +47,7 @@
 )
 from executorch.exir.backend.utils import DelegateMappingBuilder
 
-from executorch.exir.memory_planning import greedy, memory_planning_algorithm_suite
+from executorch.exir.memory_planning import greedy, MemoryPlanningAlgorithmSuite
 from executorch.exir.pass_base import ExportPass, PassBase
 
 from executorch.exir.passes import MemoryPlanningPass, SpecPropPass
@@ -199,9 +199,7 @@ def preprocess(  # noqa: C901
         # Finally, apply dynamic shape passes and memory planning pass. These passes
         # must be applied only when the graph structure is finalized.
         greedy_memory_planning = partial(greedy, allow_overlapping_allocations=False)
-        mem_planning_suite = partial(
-            memory_planning_algorithm_suite, algo_list=[greedy_memory_planning]
-        )
+        mem_planning_suite = MemoryPlanningAlgorithmSuite(algo_list=[greedy_memory_planning])
         program = apply_passes(
             program,
             [
```
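A toy sketch of why the class form reads better than wrapping the suite function in `functools.partial`: the algorithm list becomes plain constructor state. `AlgorithmSuite` and its single-algorithm `__call__` are illustrative stand-ins; the real `MemoryPlanningAlgorithmSuite` lives in `executorch.exir.memory_planning`:

```python
from functools import partial
from typing import Callable, List

Algo = Callable[[List[int]], int]


def greedy(sizes: List[int], allow_overlapping_allocations: bool = True) -> int:
    # Stand-in "algorithm": returns total memory for the sizes it plans.
    return sum(sizes)


class AlgorithmSuite:
    """Callable suite whose algorithm list is ordinary constructor state."""

    def __init__(self, algo_list: List[Algo]) -> None:
        self.algo_list = algo_list

    def __call__(self, sizes: List[int]) -> int:
        # The real suite tries algorithms in order until one produces a
        # satisfactory plan; this toy simply runs the first.
        return self.algo_list[0](sizes)


greedy_memory_planning = partial(greedy, allow_overlapping_allocations=False)
mem_planning_suite = AlgorithmSuite(algo_list=[greedy_memory_planning])
print(mem_planning_suite([8, 16, 32]))  # 56
```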
