@@ -147,9 +147,10 @@ async def close(self):
147147
148148 # Flush deferred checkpoints (from __anext__ deferral)
149149 if self ._deferred_checkpoints and self .checkpointer :
150- for shard_id , sequence in self ._deferred_checkpoints .items ():
150+ for shard_id , sequence in list ( self ._deferred_checkpoints .items () ):
151151 await self ._maybe_checkpoint (shard_id , sequence )
152- self ._deferred_checkpoints .clear ()
152+ if self ._deferred_checkpoints .get (shard_id ) == sequence :
153+ del self ._deferred_checkpoints [shard_id ]
153154
154155 # Cancel background flusher (triggers final flush via CancelledError)
155156 if self ._checkpoint_flusher_task is not None :
@@ -373,15 +374,17 @@ async def fetch(self):
373374 if self .checkpointer :
374375 # Flush deferred checkpoint if it's for this shard
375376 if shard_id in self ._deferred_checkpoints :
376- await self ._maybe_checkpoint (
377- shard_id , self ._deferred_checkpoints .pop (shard_id )
378- )
377+ seq = self ._deferred_checkpoints [shard_id ]
378+ await self ._maybe_checkpoint (shard_id , seq )
379+ if self ._deferred_checkpoints .get (shard_id ) == seq :
380+ del self ._deferred_checkpoints [shard_id ]
379381
380382 # Flush interval-buffered checkpoint for this shard
381383 if shard_id in self ._pending_checkpoints :
382- await self .checkpointer .checkpoint (
383- shard_id , self ._pending_checkpoints .pop (shard_id )
384- )
384+ seq = self ._pending_checkpoints [shard_id ]
385+ await self .checkpointer .checkpoint (shard_id , seq )
386+ if self ._pending_checkpoints .get (shard_id ) == seq :
387+ del self ._pending_checkpoints [shard_id ]
385388
386389 # Note: the terminal batch's records are enqueued but no
387390 # checkpoint sentinel was added (would race with deallocate).
@@ -877,7 +880,9 @@ async def _flush_pending_checkpoints(self):
877880 newer sequence written by _maybe_checkpoint during the await.
878881 """
879882 for shard_id in list (self ._pending_checkpoints ):
880- seq = self ._pending_checkpoints [shard_id ]
883+ seq = self ._pending_checkpoints .get (shard_id )
884+ if seq is None :
885+ continue
881886 await self .checkpointer .checkpoint (shard_id , seq )
882887 # Only delete if the value hasn't been superseded during the await
883888 if self ._pending_checkpoints .get (shard_id ) == seq :
@@ -900,9 +905,10 @@ async def __anext__(self):
900905
901906 # 1. Commit deferred checkpoints from previous iteration (implicit ACK)
902907 if self ._deferred_checkpoints :
903- for shard_id , sequence in self ._deferred_checkpoints .items ():
908+ for shard_id , sequence in list ( self ._deferred_checkpoints .items () ):
904909 await self ._maybe_checkpoint (shard_id , sequence )
905- self ._deferred_checkpoints .clear ()
910+ if self ._deferred_checkpoints .get (shard_id ) == sequence :
911+ del self ._deferred_checkpoints [shard_id ]
906912
907913 checkpoint_count = 0
908914 max_checkpoints = 100 # Prevent infinite checkpoint processing
0 commit comments