Skip to content

Commit 630ad78

Browse files
committed
.
1 parent fe6a567 commit 630ad78

4 files changed

Lines changed: 143 additions & 7 deletions

File tree

docs/DESIGN.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,20 @@ Unlike earlier versions, the consumer now fully supports dynamic resharding:
4444
- **Rate limiting**: Configurable per-shard bandwidth and record rate limits
4545
- **Async buffering**: Non-blocking `put()` operations with configurable queue sizes
4646

47+
### Checkpoint Safety
48+
49+
Checkpoints use a **deferred execution** model to prevent data loss:
50+
51+
1. **Deferred commit**: When a `__CHECKPOINT__` sentinel is dequeued from the internal queue, it is stored as pending but not committed. The checkpoint fires only at the start of the *next* `__anext__()` call, which proves the user's code survived processing the preceding records. If the consumer crashes between receiving a record and calling `__anext__()` again, the checkpoint is never committed and those records replay on restart (at-least-once delivery).
52+
53+
2. **Queue put timeout**: If enqueueing a parsed record times out (the bounded queue remains full for 30s), `LastSequenceNumber` advances only to the last fully-enqueued record. The remaining rows in the Kinesis batch are abandoned; processing them would create a non-contiguous sequence gap, and the skipped records would be lost on restart.
54+
55+
3. **Shard deallocation ordering**: When a shard iterator is exhausted (`NextShardIterator=None`), all pending checkpoints for that shard are flushed *before* `deallocate()` releases ownership. No checkpoint sentinel is enqueued for the terminal batch (it would race with deallocation); instead, those records replay on restart. Checkpoint sentinels that were already queued before deallocation are silently skipped via a `_deallocated_shards` set.
56+
57+
4. **`checkpoint_interval` debouncing**: When set, checkpoint writes are buffered in `_pending_checkpoints` and flushed by a background task every N seconds, reducing backend write pressure. The flusher uses compare-and-delete to avoid dropping a newer sequence that arrives during the `await` on the checkpoint backend. On `close()`, deferred checkpoints are committed, the flusher is cancelled (triggering a final flush), and any remaining buffered checkpoints are flushed before the checkpointer is closed.
58+
59+
These guarantees hold under single-process asyncio concurrency. For multi-process coordination, the `CheckPointer` implementation (e.g. a Redis backend with distributed locking) must handle ownership contention itself.
60+
4761
## Integration Points
4862

4963
- **Checkpointing**: Pluggable checkpointer interface (Memory, Redis) for multi-consumer coordination

kinesis/consumer.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ def __init__(
128128
self._parent_shards = set() # Track shards that are parents
129129
self._child_shards = set() # Track shards that are children
130130
self._exhausted_parents = set() # Track parent shards that are fully consumed
131+
self._deallocated_shards: set = set() # Shards already deallocated (skip stale sentinels)
131132

132133
def __aiter__(self) -> AsyncIterator[Any]:
133134
return self
@@ -295,7 +296,11 @@ async def fetch(self):
295296
log.warning("Queue put timed out, skipping record")
296297
dropped = True
297298
break
298-
if item_count > 0 and not dropped:
299+
if dropped:
300+
# Stop processing this batch — remaining rows would
301+
# create a non-contiguous sequence gap on restart.
302+
break
303+
if item_count > 0:
299304
last_enqueued_sequence = row["SequenceNumber"]
300305
total_items += item_count
301306

@@ -384,6 +389,7 @@ async def fetch(self):
384389
# restart (at-least-once safe, no data loss).
385390

386391
await self.checkpointer.deallocate(shard_id)
392+
self._deallocated_shards.add(shard_id)
387393

388394
# Remove shard iterator to stop fetching from this shard
389395
shard.pop("ShardIterator", None)
@@ -867,11 +873,15 @@ async def _flush_pending_checkpoints(self):
867873
868874
Entries are removed individually on success so that a failure
869875
mid-loop preserves the remaining (unflushed) checkpoints for
870-
the next attempt.
876+
the next attempt. Uses compare-and-delete to avoid dropping a
877+
newer sequence written by _maybe_checkpoint during the await.
871878
"""
872879
for shard_id in list(self._pending_checkpoints):
873-
await self.checkpointer.checkpoint(shard_id, self._pending_checkpoints[shard_id])
874-
del self._pending_checkpoints[shard_id]
880+
seq = self._pending_checkpoints[shard_id]
881+
await self.checkpointer.checkpoint(shard_id, seq)
882+
# Only delete if the value hasn't been superseded during the await
883+
if self._pending_checkpoints.get(shard_id) == seq:
884+
del self._pending_checkpoints[shard_id]
875885

876886
async def __anext__(self):
877887

@@ -906,9 +916,13 @@ async def __anext__(self):
906916
raise StopAsyncIteration from None
907917

908918
if item and isinstance(item, dict) and "__CHECKPOINT__" in item:
909-
if self.checkpointer:
919+
cp_shard = item["__CHECKPOINT__"]["ShardId"]
920+
if cp_shard in self._deallocated_shards:
921+
# Stale sentinel queued before deallocation; skip silently
922+
log.debug("Skipping checkpoint for deallocated shard %s", cp_shard)
923+
elif self.checkpointer:
910924
# Don't execute now — defer to next __anext__ call
911-
self._deferred_checkpoints[item["__CHECKPOINT__"]["ShardId"]] = item[
925+
self._deferred_checkpoints[cp_shard] = item[
912926
"__CHECKPOINT__"
913927
]["SequenceNumber"]
914928
checkpoint_count += 1

tests/conftest.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,15 @@ def _factory(**kwargs):
191191

192192
for c in consumers:
193193
for task in [getattr(c, "_checkpoint_flusher_task", None), c.fetch_task]:
194-
if task and not task.done():
194+
if task is None:
195+
continue
196+
if task.done():
197+
# Surface exceptions from tasks that finished during the test
198+
if not task.cancelled() and task.exception():
199+
log.warning(
200+
"Task finished with error during test: %s", task.exception(), exc_info=task.exception()
201+
)
202+
else:
195203
task.cancel()
196204
try:
197205
await task

tests/test_checkpoint_ordering.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,58 @@ async def test_no_checkpointer_skips_checkpoint_processing(self, mock_consumer):
141141
class TestQueuePutTimeoutFix:
142142
"""Verify LastSequenceNumber only advances to last successfully enqueued record."""
143143

144+
@pytest.mark.asyncio
145+
async def test_queue_put_timeout_breaks_outer_loop(self, mock_consumer):
146+
"""When queue.put() times out on row 2 of a 3-row batch, row 3 is NOT
147+
processed — prevents a non-contiguous sequence gap on restart."""
148+
consumer = mock_consumer(sleep_time_no_records=0)
149+
consumer.checkpointer = MemoryCheckPointer(name="test")
150+
await consumer.checkpointer.allocate("shard-0")
151+
consumer.refresh_shards = AsyncMock()
152+
consumer.get_records = AsyncMock(return_value=None)
153+
154+
# Fail on the second row's put (row 2), succeed on row 1
155+
put_count = 0
156+
original_put = consumer.queue.put
157+
158+
async def failing_put(item):
159+
nonlocal put_count
160+
put_count += 1
161+
if put_count == 2: # Row 2's first output
162+
raise asyncio.TimeoutError("simulated queue full")
163+
await original_put(item)
164+
165+
consumer.queue.put = failing_put
166+
167+
shard = consumer.shards[0]
168+
shard["ShardIterator"] = "iter-0"
169+
shard["stats"] = ShardStats()
170+
shard["throttler"] = Throttler(rate_limit=1, period=1)
171+
172+
fetch_result = {
173+
"Records": [
174+
{"SequenceNumber": "100", "Data": b'{"msg": "r1"}'},
175+
{"SequenceNumber": "200", "Data": b'{"msg": "r2"}'},
176+
{"SequenceNumber": "300", "Data": b'{"msg": "r3"}'},
177+
],
178+
"NextShardIterator": "iter-next",
179+
}
180+
fut = asyncio.get_running_loop().create_future()
181+
fut.set_result(fetch_result)
182+
shard["fetch"] = fut
183+
184+
await consumer.fetch()
185+
186+
# Row 1 succeeded, row 2 dropped → outer loop breaks, row 3 never tried
187+
assert shard.get("LastSequenceNumber") == "100"
188+
# Only 1 data item should be in the queue (row 1)
189+
items = []
190+
while not consumer.queue.empty():
191+
items.append(consumer.queue.get_nowait())
192+
data_items = [i for i in items if isinstance(i, dict) and "msg" in i]
193+
assert len(data_items) == 1
194+
assert data_items[0]["msg"] == "r1"
195+
144196
@pytest.mark.asyncio
145197
async def test_queue_put_timeout_no_sequence_advance(self, mock_consumer):
146198
"""When queue.put() times out mid-batch, LastSequenceNumber tracks only
@@ -272,6 +324,28 @@ async def test_shard_exhaustion_with_records_no_sentinel_enqueued(self, mock_con
272324
assert len(data_items) == 2
273325
assert checkpoint_items == [], "No sentinel for terminal batch (would race with deallocate)"
274326

327+
@pytest.mark.asyncio
328+
async def test_stale_sentinel_skipped_after_deallocation(self, mock_consumer):
329+
"""A checkpoint sentinel queued before deallocation is silently skipped
330+
in __anext__ rather than checkpointing a deallocated shard."""
331+
consumer = mock_consumer()
332+
checkpointer = _make_mock_checkpointer()
333+
consumer.checkpointer = checkpointer
334+
335+
# Simulate shard-0 already deallocated (e.g. shard exhaustion ran)
336+
consumer._deallocated_shards.add("shard-0")
337+
338+
# Stale sentinel arrives for deallocated shard, then real data
339+
await consumer.queue.put({"__CHECKPOINT__": {"ShardId": "shard-0", "SequenceNumber": "100"}})
340+
await consumer.queue.put({"msg": "from-other-shard"})
341+
342+
item = await consumer.__anext__()
343+
assert item == {"msg": "from-other-shard"}
344+
345+
# Sentinel must NOT have been deferred or committed
346+
assert "shard-0" not in consumer._deferred_checkpoints
347+
checkpointer.checkpoint.assert_not_awaited()
348+
275349
@pytest.mark.asyncio
276350
async def test_close_flushes_then_deallocates(self, mock_consumer):
277351
"""close() flushes all pending checkpoints before deallocating shards."""
@@ -395,6 +469,32 @@ async def test_checkpoint_interval_with_auto_checkpoint_false_raises(self):
395469
checkpointer=AsyncMock(auto_checkpoint=False),
396470
)
397471

472+
@pytest.mark.asyncio
473+
async def test_flush_preserves_newer_sequence_written_during_await(self, mock_consumer):
474+
"""If _maybe_checkpoint writes a newer sequence for a shard while
475+
_flush_pending_checkpoints is awaiting the backend, the newer value
476+
must survive (compare-and-delete, not unconditional delete)."""
477+
consumer = mock_consumer(checkpoint_interval=60.0)
478+
consumer.checkpointer = _make_mock_checkpointer()
479+
480+
# Simulate: flusher reads seq "100" for shard-0, then during the
481+
# checkpoint() await, _maybe_checkpoint overwrites with "200".
482+
original_checkpoint = consumer.checkpointer.checkpoint
483+
484+
async def checkpoint_with_concurrent_write(shard_id, seq):
485+
# Simulate _maybe_checkpoint writing a newer value during this await
486+
if seq == "100":
487+
consumer._pending_checkpoints["shard-0"] = "200"
488+
await original_checkpoint(shard_id, seq)
489+
490+
consumer.checkpointer.checkpoint = AsyncMock(side_effect=checkpoint_with_concurrent_write)
491+
492+
consumer._pending_checkpoints["shard-0"] = "100"
493+
await consumer._flush_pending_checkpoints()
494+
495+
# "100" was flushed but "200" must still be pending (not deleted)
496+
assert consumer._pending_checkpoints.get("shard-0") == "200"
497+
398498
@pytest.mark.asyncio
399499
async def test_checkpoint_backend_raises_during_flush(self, mock_consumer):
400500
"""Exception from checkpoint backend during flush propagates from close()."""

0 commit comments

Comments
 (0)