Skip to content

Commit 0c32690

Browse files
committed
.
1 parent 630ad78 commit 0c32690

File tree

2 files changed

+26
-14
lines changed

2 files changed

+26
-14
lines changed

kinesis/consumer.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,10 @@ async def close(self):
147147

148148
# Flush deferred checkpoints (from __anext__ deferral)
149149
if self._deferred_checkpoints and self.checkpointer:
150-
for shard_id, sequence in self._deferred_checkpoints.items():
150+
for shard_id, sequence in list(self._deferred_checkpoints.items()):
151151
await self._maybe_checkpoint(shard_id, sequence)
152-
self._deferred_checkpoints.clear()
152+
if self._deferred_checkpoints.get(shard_id) == sequence:
153+
del self._deferred_checkpoints[shard_id]
153154

154155
# Cancel background flusher (triggers final flush via CancelledError)
155156
if self._checkpoint_flusher_task is not None:
@@ -373,15 +374,17 @@ async def fetch(self):
373374
if self.checkpointer:
374375
# Flush deferred checkpoint if it's for this shard
375376
if shard_id in self._deferred_checkpoints:
376-
await self._maybe_checkpoint(
377-
shard_id, self._deferred_checkpoints.pop(shard_id)
378-
)
377+
seq = self._deferred_checkpoints[shard_id]
378+
await self._maybe_checkpoint(shard_id, seq)
379+
if self._deferred_checkpoints.get(shard_id) == seq:
380+
del self._deferred_checkpoints[shard_id]
379381

380382
# Flush interval-buffered checkpoint for this shard
381383
if shard_id in self._pending_checkpoints:
382-
await self.checkpointer.checkpoint(
383-
shard_id, self._pending_checkpoints.pop(shard_id)
384-
)
384+
seq = self._pending_checkpoints[shard_id]
385+
await self.checkpointer.checkpoint(shard_id, seq)
386+
if self._pending_checkpoints.get(shard_id) == seq:
387+
del self._pending_checkpoints[shard_id]
385388

386389
# Note: the terminal batch's records are enqueued but no
387390
# checkpoint sentinel was added (would race with deallocate).
@@ -877,7 +880,9 @@ async def _flush_pending_checkpoints(self):
877880
newer sequence written by _maybe_checkpoint during the await.
878881
"""
879882
for shard_id in list(self._pending_checkpoints):
880-
seq = self._pending_checkpoints[shard_id]
883+
seq = self._pending_checkpoints.get(shard_id)
884+
if seq is None:
885+
continue
881886
await self.checkpointer.checkpoint(shard_id, seq)
882887
# Only delete if the value hasn't been superseded during the await
883888
if self._pending_checkpoints.get(shard_id) == seq:
@@ -900,9 +905,10 @@ async def __anext__(self):
900905

901906
# 1. Commit deferred checkpoints from previous iteration (implicit ACK)
902907
if self._deferred_checkpoints:
903-
for shard_id, sequence in self._deferred_checkpoints.items():
908+
for shard_id, sequence in list(self._deferred_checkpoints.items()):
904909
await self._maybe_checkpoint(shard_id, sequence)
905-
self._deferred_checkpoints.clear()
910+
if self._deferred_checkpoints.get(shard_id) == sequence:
911+
del self._deferred_checkpoints[shard_id]
906912

907913
checkpoint_count = 0
908914
max_checkpoints = 100 # Prevent infinite checkpoint processing

tests/conftest.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -189,12 +189,15 @@ def _factory(**kwargs):
189189

190190
yield _factory
191191

192+
teardown_errors = []
192193
for c in consumers:
193194
for task in [getattr(c, "_checkpoint_flusher_task", None), c.fetch_task]:
194195
if task is None:
195196
continue
196197
if task.done():
197-
# Surface exceptions from tasks that finished during the test
198+
# Log exceptions from tasks that finished during the test.
199+
# Don't re-raise: the test itself is responsible for observing
200+
# expected task failures (e.g. test_wait_ready_fetch_task_crash).
198201
if not task.cancelled() and task.exception():
199202
log.warning(
200203
"Task finished with error during test: %s", task.exception(), exc_info=task.exception()
@@ -205,5 +208,8 @@ def _factory(**kwargs):
205208
await task
206209
except asyncio.CancelledError:
207210
pass
208-
except Exception:
209-
log.warning("Unexpected error during mock_consumer teardown", exc_info=True)
211+
except Exception as exc:
212+
teardown_errors.append(exc)
213+
if teardown_errors:
214+
log.warning("mock_consumer teardown errors: %s", teardown_errors)
215+
raise teardown_errors[0]

0 commit comments

Comments
 (0)