|
12 | 12 | import math
|
13 | 13 | import typing
|
14 | 14 | from functools import partial
|
15 |
| -from typing import Iterable, List, Optional, Tuple |
| 15 | +from typing import Iterable, List, Optional, Set, Tuple |
16 | 16 |
|
17 | 17 | import torch
|
18 | 18 | from executorch.backends.cadence.aot.memory_constraints import (
|
@@ -73,11 +73,11 @@ def collect_specs_from_graph_module(
|
73 | 73 | # the fastest memory available
|
74 | 74 | # flake8: noqa 'position_based_greedy_with_hierarchy' is too complex (13)
|
75 | 75 | def position_based_greedy_with_hierarchy(
|
76 |
| - graph_module: torch.fx.GraphModule, |
77 | 76 | alignment: int,
|
| 77 | + specs: Set[TensorSpec], |
| 78 | + graph_module: torch.fx.GraphModule, |
78 | 79 | graph_signature: ExportGraphSignature,
|
79 |
| - alloc_graph_input: bool, |
80 |
| - alloc_graph_output: bool, |
| 80 | + extra_padding: int = 0, |
81 | 81 | *,
|
82 | 82 | memory_config: MemoryConfig,
|
83 | 83 | mem_constraints: MemConstraints,
|
@@ -119,9 +119,7 @@ def memory_available(spec: TensorSpec) -> bool:
|
119 | 119 |
|
120 | 120 | # Iterate over all the specs in sorted order
|
121 | 121 | for spec in sorted(
|
122 |
| - collect_specs_from_graph_module( |
123 |
| - graph_module, graph_signature, alloc_graph_input, alloc_graph_output |
124 |
| - ), |
| 122 | + specs, |
125 | 123 | key=lambda spec: spec.allocated_memory,
|
126 | 124 | reverse=True,
|
127 | 125 | ):
|
@@ -167,11 +165,11 @@ def memory_available(spec: TensorSpec) -> bool:
|
167 | 165 |
|
168 | 166 | # Greedy tensor placement with the heuristics from arxiv.org/pdf/2001.03288.pdf
|
169 | 167 | def greedy_by_size_for_offset_calculation_with_hierarchy(
|
170 |
| - graph_module: torch.fx.GraphModule, |
171 | 168 | alignment: int,
|
| 169 | + specs: Set[TensorSpec], |
| 170 | + graph_module: torch.fx.GraphModule, |
172 | 171 | graph_signature: ExportGraphSignature,
|
173 |
| - alloc_graph_input: bool, |
174 |
| - alloc_graph_output: bool, |
| 172 | + extra_padding: int = 0, |
175 | 173 | *,
|
176 | 174 | memory_config: MemoryConfig,
|
177 | 175 | mem_constraints: MemConstraints,
|
@@ -199,9 +197,7 @@ def greedy_by_size_for_offset_calculation_with_hierarchy(
|
199 | 197 |
|
200 | 198 | # Iterate over all the specs in sorted order
|
201 | 199 | for spec in sorted(
|
202 |
| - collect_specs_from_graph_module( |
203 |
| - graph_module, graph_signature, alloc_graph_input, alloc_graph_output |
204 |
| - ), |
| 200 | + specs, |
205 | 201 | key=lambda spec: spec.allocated_memory,
|
206 | 202 | reverse=True,
|
207 | 203 | ):
|
|
0 commit comments