Commit a727b55

fix delegate cache duplicate bug

Differential Revision: D67067997
Pull Request resolved: #7281

1 parent ba6c552 commit a727b55
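
Summary of the change: _EmitterState.delegate_cache previously keyed on the raw processed_bytes and Program.backend_delegate_data had no cross-method cache, so a module delegated from multiple entry points was appended to the program once per method. This commit keys both caches on a sha256 digest of the payload, adds a program-level backend_delegate_data_cache shared across all entry points, and also deduplicates identical inline delegate blobs when extracting segments at serialization time.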

File tree: 5 files changed, +79 -10 lines

exir/_serialize/_program.py (+6, -2)

@@ -224,6 +224,7 @@ def _extract_delegate_segments(
     """
     remaining_inline: List[BackendDelegateInlineData] = []
     inline_indices_seen: set[int] = set()
+    segment_index_map: dict[bytes, int] = {}
     for plan in program.execution_plan:
         for delegate in plan.delegates:
             if delegate.processed.location != DataLocation.INLINE:
@@ -249,8 +250,11 @@ def _extract_delegate_segments(
             inline_indices_seen.add(delegate.processed.index)
             if inline.data:
                 # Move the delegate data out of the program.
-                segment_index = len(segments)
-                segments.append(Cord(inline.data))
+                segment_index = segment_index_map.get(inline.data)
+                if segment_index is None:
+                    segment_index = len(segments)
+                    segments.append(Cord(inline.data))
+                    segment_index_map[inline.data] = segment_index
                 delegate.processed = BackendDelegateDataReference(
                     location=DataLocation.SEGMENT,
                     index=segment_index,
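
With segment_index_map in place, _extract_delegate_segments writes each unique inline delegate payload to the segment list once and points every delegate carrying the same bytes at the shared index. A minimal standalone sketch of the same dedup pattern, assuming nothing about ExecuTorch's types (the function name and signature are illustrative only):

from typing import Dict, List, Tuple

def dedup_segments(payloads: List[bytes]) -> Tuple[List[bytes], List[int]]:
    """Store each unique payload once; return segments plus a per-payload index."""
    segments: List[bytes] = []
    segment_index_map: Dict[bytes, int] = {}  # payload -> index of its segment
    indices: List[int] = []
    for data in payloads:
        segment_index = segment_index_map.get(data)
        if segment_index is None:  # first time this payload is seen
            segment_index = len(segments)
            segments.append(data)
            segment_index_map[data] = segment_index
        indices.append(segment_index)
    return segments, indices

# Two references to the same blob resolve to a single stored segment:
segments, indices = dedup_segments([b"blob-a", b"blob-b", b"blob-a"])
assert segments == [b"blob-a", b"blob-b"] and indices == [0, 1, 0]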

exir/backend/test/demos/rpc/TARGETS (+1)

@@ -28,6 +28,7 @@ runtime.python_library(
    ],
    visibility = [
        "//executorch/exir/backend/test/...",
+       "//executorch/exir/emit/test/...",
    ],
    deps = [
        ":executor_backend_preprocess",
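
The new visibility entry lets the test target in exir/emit/test (see its TARGETS change below) depend on this demo RPC backend.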

exir/emit/_emitter.py (+16, -7)

@@ -122,6 +122,8 @@ class _ProgramState:
     # Delegate data stored directly in the flatbuffer. Pointed to by BackendDelegateDataReference,
     # and should be copied to Program.backend_delegate_data.
     backend_delegate_data: List[BackendDelegateInlineData] = field(default_factory=list)
+    # Delegate cache that is used across all entry points. Key is the hash of the delegated payload.
+    backend_delegate_data_cache: Dict[str, int] = field(default_factory=dict)

     # Constants are optionally stored in external files.
     # Aggregate unique external constants into one buffer.
@@ -144,7 +146,8 @@ class _EmitterState:
     operators: List[Operator]
     delegates: List[BackendDelegate]
     operator_cache: Dict[Tuple[str, str], int]
-    delegate_cache: Dict[bytes, int]
+    # delegate_cache: the key is hash(delegated_payload) and the value is the index in delegates
+    delegate_cache: Dict[str, int]
     emit_stacktrace: bool

     spec2id_dict: Dict[TensorSpec, int] = field(default_factory=dict)
@@ -1073,8 +1076,8 @@ def _emit_delegate(
         """Emit the delegates inputs and outputs as specified by the schema, then emit the
         delegate's blob."""
         processed_bytes = lowered_module.processed_bytes
-
-        delegate_index = self.emitter_state.delegate_cache.get(processed_bytes)
+        hashed = hashlib.sha256(processed_bytes).hexdigest()
+        delegate_index = self.emitter_state.delegate_cache.get(hashed)
         delegate_ret = None

         if isinstance(self.node.meta["spec"], list):
@@ -1112,10 +1115,16 @@ def _emit_delegate(
         if delegate_index is None:
             # Allocate an entry for the data. TODO(T150113674): Reuse any duplicate entries if
             # present.
-            data_index: int = len(self.program_state.backend_delegate_data)
-            self.program_state.backend_delegate_data.append(
-                BackendDelegateInlineData(data=processed_bytes)
+            hashed = hashlib.sha256(processed_bytes).hexdigest()
+            data_index: Optional[int] = (
+                self.program_state.backend_delegate_data_cache.get(hashed)
             )
+            if data_index is None:
+                data_index = len(self.program_state.backend_delegate_data)
+                self.program_state.backend_delegate_data_cache[hashed] = data_index
+                self.program_state.backend_delegate_data.append(
+                    BackendDelegateInlineData(data=processed_bytes)
+                )

             backend_delegate = BackendDelegate(
                 id=lowered_module.backend_id,
@@ -1126,7 +1135,7 @@ def _emit_delegate(
             )
             delegate_index = len(self.emitter_state.delegate_cache)
             self.emitter_state.delegates.append(backend_delegate)
-            self.emitter_state.delegate_cache[processed_bytes] = delegate_index
+            self.emitter_state.delegate_cache[hashed] = delegate_index

         # TODO(angelayi) Will need to emit the kwargs too, in the correct order according to the
         # function's spec and with default arguments. This requires us to store the function's spec
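
Both caches now key on a sha256 hexdigest of the payload rather than the raw bytes, so dictionary keys stay constant-size even when delegate blobs are large, and the program-level backend_delegate_data_cache persists across entry points where the per-method delegate_cache does not. A minimal sketch of the program-level behavior, using hypothetical names (DelegateDataCache and add are stand-ins for illustration, not the emitter's actual API):

import hashlib
from typing import Dict, List

class DelegateDataCache:
    """Toy model of _ProgramState's delegate-data list plus its hash-keyed cache."""

    def __init__(self) -> None:
        self.backend_delegate_data: List[bytes] = []
        # Key: sha256 hexdigest of the payload; value: index into the list above.
        self.backend_delegate_data_cache: Dict[str, int] = {}

    def add(self, processed_bytes: bytes) -> int:
        hashed = hashlib.sha256(processed_bytes).hexdigest()
        data_index = self.backend_delegate_data_cache.get(hashed)
        if data_index is None:  # unseen payload: append it and remember its index
            data_index = len(self.backend_delegate_data)
            self.backend_delegate_data_cache[hashed] = data_index
            self.backend_delegate_data.append(processed_bytes)
        return data_index

cache = DelegateDataCache()
assert cache.add(b"payload") == cache.add(b"payload")  # same blob, one entry
assert len(cache.backend_delegate_data) == 1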

exir/emit/test/TARGETS (+1)

@@ -16,6 +16,7 @@ python_unittest(
        "//executorch/exir:lib",
        "//executorch/exir:print_program",
        "//executorch/exir:schema",
+       "//executorch/exir/backend/test/demos/rpc:executor_backend_partitioner",
        "//executorch/exir/backend:backend_api",
        "//executorch/exir/emit:lib",
        "//executorch/exir/passes:const_prop_pass",

exir/emit/test/test_emit.py (+55, -1)

@@ -27,6 +27,9 @@
 from executorch.exir._serialize._program import deserialize_pte_binary
 from executorch.exir.backend.backend_api import to_backend
 from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
+from executorch.exir.backend.test.demos.rpc.executor_backend_partitioner import (
+    ExecutorBackendPartitioner,
+)
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.emit import emit_program  # noqa
 from executorch.exir.error import InternalError
@@ -63,7 +66,7 @@
 from functorch.experimental import control_flow
 from torch import nn

-from torch.export import Dim, export
+from torch.export import Dim, export, export_for_training


 class WrapperModule(torch.nn.Module):
@@ -1679,3 +1682,54 @@ def forward(self, x):
         ]
         self.assertEqual(external_map["linear.weight"], 0)
         self.assertEqual(external_map["linear.bias"], 1)
+
+    def test_delegate_deduplicate(self) -> None:
+        class SharedModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(2, 2)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        class Module1(torch.nn.Module):
+            def __init__(self, shared_module):
+                super().__init__()
+                self.shared_module = shared_module
+
+            def forward(self, x):
+                return self.shared_module(x)
+
+        class Module2(torch.nn.Module):
+            def __init__(self, shared_module):
+                super().__init__()
+                self.shared_module = shared_module
+
+            def forward(self, x):
+                return self.shared_module(x)
+
+        shared_module = SharedModule()
+        module_1 = Module1(shared_module)
+        module_2 = Module2(shared_module)
+        example_inputs = (torch.randn(2, 2),)
+        module_1(*example_inputs)
+        module_2(*example_inputs)
+
+        ep1 = export_for_training(module_1, example_inputs)
+        ep2 = export_for_training(module_2, example_inputs)
+
+        edge_program_manager = exir.to_edge(
+            {"forward1": ep1, "forward2": ep2},
+            compile_config=exir.EdgeCompileConfig(
+                _check_ir_validity=False, _use_edge_ops=True
+            ),
+        )
+
+        edge_program_manager = edge_program_manager.to_backend(
+            ExecutorBackendPartitioner()
+        ).to_executorch()
+
+        # Check that there is only one delegate because two methods are exactly the same
+        self.assertEqual(
+            len(edge_program_manager.executorch_program.backend_delegate_data), 1
+        )
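
The test exports the same SharedModule under two entry points, forward1 and forward2, lowers both through ExecutorBackendPartitioner, and asserts that the serialized program carries exactly one backend_delegate_data entry; before this fix the identical payload would have been stored once per method.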
