@@ -122,6 +122,8 @@ class _ProgramState:
122
122
# Delegate data stored directly in the flatbuffer. Pointed to by BackendDelegateDataReference,
123
123
# and should be copied to Program.backend_delegate_data.
124
124
backend_delegate_data : List [BackendDelegateInlineData ] = field (default_factory = list )
125
+ # Delegate cache that is used across all entry points. Key is the hash of the delegated payload.
126
+ backend_delegate_data_cache : Dict [str , int ] = field (default_factory = dict )
125
127
126
128
# Constants are optionally stored in external files.
127
129
# Aggregate unique external constants into one buffer.
@@ -144,7 +146,8 @@ class _EmitterState:
144
146
operators : List [Operator ]
145
147
delegates : List [BackendDelegate ]
146
148
operator_cache : Dict [Tuple [str , str ], int ]
147
- delegate_cache : Dict [bytes , int ]
149
+ # delegate_cache: the key is hash(delegated_payload) and the value is the index in delegates
150
+ delegate_cache : Dict [str , int ]
148
151
emit_stacktrace : bool
149
152
150
153
spec2id_dict : Dict [TensorSpec , int ] = field (default_factory = dict )
@@ -1073,8 +1076,8 @@ def _emit_delegate(
1073
1076
"""Emit the delegates inputs and outputs as specified by the schema, then emit the
1074
1077
delegate's blob."""
1075
1078
processed_bytes = lowered_module .processed_bytes
1076
-
1077
- delegate_index = self .emitter_state .delegate_cache .get (processed_bytes )
1079
+ hashed = hashlib . sha256 ( processed_bytes ). hexdigest ()
1080
+ delegate_index = self .emitter_state .delegate_cache .get (hashed )
1078
1081
delegate_ret = None
1079
1082
1080
1083
if isinstance (self .node .meta ["spec" ], list ):
@@ -1112,10 +1115,16 @@ def _emit_delegate(
1112
1115
if delegate_index is None :
1113
1116
# Allocate an entry for the data. TODO(T150113674): Reuse any duplicate entries if
1114
1117
# present.
1115
- data_index : int = len ( self . program_state . backend_delegate_data )
1116
- self . program_state . backend_delegate_data . append (
1117
- BackendDelegateInlineData ( data = processed_bytes )
1118
+ hashed = hashlib . sha256 ( processed_bytes ). hexdigest ( )
1119
+ data_index : Optional [ int ] = (
1120
+ self . program_state . backend_delegate_data_cache . get ( hashed )
1118
1121
)
1122
+ if data_index is None :
1123
+ data_index = len (self .program_state .backend_delegate_data )
1124
+ self .program_state .backend_delegate_data_cache [hashed ] = data_index
1125
+ self .program_state .backend_delegate_data .append (
1126
+ BackendDelegateInlineData (data = processed_bytes )
1127
+ )
1119
1128
1120
1129
backend_delegate = BackendDelegate (
1121
1130
id = lowered_module .backend_id ,
@@ -1126,7 +1135,7 @@ def _emit_delegate(
1126
1135
)
1127
1136
delegate_index = len (self .emitter_state .delegate_cache )
1128
1137
self .emitter_state .delegates .append (backend_delegate )
1129
- self .emitter_state .delegate_cache [processed_bytes ] = delegate_index
1138
+ self .emitter_state .delegate_cache [hashed ] = delegate_index
1130
1139
1131
1140
# TODO(angelayi) Will need to emit the kwargs too, in the correct order according to the
1132
1141
# function's spec and with default arguments. This requires us to store the function's spec
0 commit comments