@@ -150,23 +150,94 @@ def _put_tensors_thread_worker(self, batch_keys: list[str], batch_tensors: list[
150150
151151 try :
152152 results = self ._store .batch_upsert_from (batch_keys , batch_ptrs , batch_sizes , config = self .replica_config )
153- if not all (r == 0 for r in results ):
154- failed_indices = [j for j , r in enumerate (results ) if r != 0 ]
155- error_codes = [results [j ] for j in failed_indices ]
153+ if len (results ) != len (batch_keys ):
154+ raise RuntimeError (f"batch_upsert_from returned { len (results )} results, expected { len (batch_keys )} " )
155+
156+ failed_indices = [j for j , r in enumerate (results ) if r != 0 ]
157+ if not failed_indices :
158+ return
159+
160+ current_failed_keys = [batch_keys [i ] for i in failed_indices ]
161+ current_failed_codes = [results [i ] for i in failed_indices ]
162+ current_failed_indices = failed_indices
163+
164+ logger .error (
165+ f"batch_upsert_from failed for keys { current_failed_keys } with error codes { current_failed_codes } . "
166+ f"Retrying up to { MAX_RETRIES } times..."
167+ )
168+
169+ for attempt in range (1 , MAX_RETRIES + 1 ):
170+ retry_ptrs = [batch_ptrs [i ] for i in current_failed_indices ]
171+ retry_sizes = [batch_sizes [i ] for i in current_failed_indices ]
172+
173+ retry_results = self ._store .batch_upsert_from (
174+ current_failed_keys , retry_ptrs , retry_sizes , config = self .replica_config
175+ )
176+
177+ next_failed_indices = []
178+ next_failed_keys = []
179+ next_failed_codes = []
180+
181+ for i , ret in enumerate (retry_results ):
182+ if ret != 0 :
183+ next_failed_indices .append (current_failed_indices [i ])
184+ next_failed_keys .append (current_failed_keys [i ])
185+ next_failed_codes .append (ret )
186+
187+ if not next_failed_indices :
188+ logger .info ("batch_upsert_from succeeded after retransmission." )
189+ break # All retries in this attempt succeeded.
190+
191+ logger .error (
192+ f"batch_upsert_from retry { attempt } /{ MAX_RETRIES } failed for { len (next_failed_keys )} keys "
193+ f"with error codes { next_failed_codes } ."
194+ )
195+
196+ current_failed_indices = next_failed_indices
197+ current_failed_keys = next_failed_keys
198+ current_failed_codes = next_failed_codes
199+
200+ if attempt < MAX_RETRIES :
201+ time .sleep (RETRY_DELAY_SECONDS )
202+ else :
156203 raise RuntimeError (
157- f"batch_upsert_from failed for indices { failed_indices } with error codes: { error_codes } "
204+ f"batch_upsert_from failed for keys { current_failed_keys } with error codes "
205+ f"{ current_failed_codes } after retrying { MAX_RETRIES } times."
158206 )
207+
159208 finally :
160209 self ._unregister_all_buffers (batch_ptr_reduced )
161210
162211 def _put_bytes_thread_worker (self , batch_keys : list [str ], batch_values : list [Any ]):
163212 """Worker thread for putting batch of non-tensors to MooncakeStore."""
164213
165- batch_values = [pickle .dumps (v , protocol = pickle .HIGHEST_PROTOCOL ) for v in batch_values ]
214+ serialized_values = [pickle .dumps (v , protocol = pickle .HIGHEST_PROTOCOL ) for v in batch_values ]
166215
167- ret = self ._store .upsert_batch (batch_keys , batch_values , self .replica_config )
168- if ret != 0 :
169- raise RuntimeError (f"upsert_batch failed with error code: { ret } " )
216+ # FIXME: Use element-level ret value to precise retransmit when MooncakeStore supports
217+ ret = self ._store .upsert_batch (batch_keys , serialized_values , self .replica_config )
218+ if ret == 0 :
219+ return
220+
221+ logger .error (
222+ f"upsert_batch failed for { len (batch_keys )} keys with error code: { ret } . "
223+ f"Retrying up to { MAX_RETRIES } times..."
224+ )
225+
226+ for attempt in range (1 , MAX_RETRIES + 1 ):
227+ ret = self ._store .upsert_batch (batch_keys , serialized_values , self .replica_config )
228+ if ret == 0 :
229+ logger .info ("upsert_batch succeeded after retransmission." )
230+ return
231+
232+ logger .error (
233+ f"upsert_batch retry { attempt } /{ MAX_RETRIES } failed for { len (batch_keys )} keys with error code: { ret } ."
234+ )
235+ if attempt < MAX_RETRIES :
236+ time .sleep (RETRY_DELAY_SECONDS )
237+
238+ raise RuntimeError (
239+ f"upsert_batch failed for { len (batch_keys )} keys with error code: { ret } after retrying { MAX_RETRIES } times."
240+ )
170241
171242 def get (
172243 self ,
@@ -274,8 +345,14 @@ def _get_tensors_thread_worker(
274345 next_failed_codes .append (ret )
275346
276347 if not next_failed_indices :
348+ logger .info ("batch_get_into succeeded after retransmission." )
277349 break # All retries in this attempt succeeded.
278350
351+ logger .error (
352+ f"batch_get_into retry { attempt } /{ MAX_RETRIES } failed for { len (next_failed_keys )} keys "
353+ f"with error codes { next_failed_codes } ."
354+ )
355+
279356 # Narrow down to still-failed items for the next retry attempt.
280357 current_failed_indices = next_failed_indices
281358 current_failed_keys = next_failed_keys
@@ -300,7 +377,7 @@ def _get_bytes_thread_worker(self, batch_keys: list[str], indexes: list[int]) ->
300377 if len (raw_results ) != len (batch_keys ):
301378 raise RuntimeError (f"get_batch returned { len (raw_results )} items, expected { len (batch_keys )} " )
302379
303- # TODO : Use MooncakeStore provided ret codes to detect transmission failures when supported
380+ # FIXME : Use MooncakeStore provided ret codes to detect transmission failures when supported
304381 # Currently we rely on empty bytes (b'') to detect transmission failures because
305382 # MooncakeStore does not currently return a separate status code per key.
306383 failed_indices = [i for i , result in enumerate (raw_results ) if result == b"" ]
@@ -326,8 +403,11 @@ def _get_bytes_thread_worker(self, batch_keys: list[str], indexes: list[int]) ->
326403 raw_results [original_idx ] = result
327404
328405 if not next_failed_indices :
406+ logger .info ("get_batch succeeded after retransmission." )
329407 break # All retries in this attempt succeeded.
330408
409+ logger .error (f"get_batch retry { attempt } /{ MAX_RETRIES } failed for { len (next_failed_keys )} keys." )
410+
331411 # Narrow down to still-failed items for the next retry attempt.
332412 current_failed_keys = next_failed_keys
333413 current_failed_indices = next_failed_indices
@@ -336,10 +416,8 @@ def _get_bytes_thread_worker(self, batch_keys: list[str], indexes: list[int]) ->
336416 time .sleep (RETRY_DELAY_SECONDS )
337417 else :
338418 # All retries exhausted.
339- # FIXME: raise error here when we can distinguish transmission failures from empty values
340- logger .error (
341- f"get_batch failed for keys { current_failed_keys } after retrying { MAX_RETRIES } times. "
342- f"Please validate if the values corresponding to these keys are `None` during put."
419+ raise RuntimeError (
420+ f"get_batch failed for keys { current_failed_keys } after retrying { MAX_RETRIES } times."
343421 )
344422
345423 deserialized_results = [pickle .loads (result ) if result != b"" else None for result in raw_results ]
0 commit comments