Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 47 additions & 6 deletions kinesis/checkpointers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,67 @@


class CheckPointer(Protocol):
    """Protocol for checkpointer implementations.

    Checkpointers track processing progress per shard so that a consumer
    can resume from the correct position after a restart. They also provide
    shard-level locking so that multiple consumers don't process the same
    shard concurrently.
    """

    async def allocate(self, shard_id: str) -> Tuple[bool, Optional[str]]:
        """Allocate a shard for processing.

        Returns (True, sequence) if allocation succeeded. The sequence is the
        last checkpointed position (None if no prior checkpoint exists).
        Returns (False, None) if the shard is owned by another consumer.

        Implementations must be safe to call multiple times for the same shard
        (idempotent if already allocated by this consumer).
        """
        ...

    async def deallocate(self, shard_id: str) -> None:
        """Release a shard (e.g., on shard closure or consumer shutdown).

        Must preserve the last checkpoint sequence for future consumers.
        Called when a shard's iterator is exhausted (resharding) or on
        consumer close(). The consumer guarantees all pending checkpoints
        for this shard are flushed before deallocate() is called.
        """
        ...

    async def checkpoint(self, shard_id: str, sequence_number: str) -> None:
        """Record processing progress for a shard.

        Called after the consumer has yielded all records up to sequence_number
        to the user and the user has returned control (called __anext__ again).
        At-least-once semantics: the last batch's records may be reprocessed
        on restart if close() was not called after the final iteration.

        Implementations should be idempotent and handle out-of-order calls
        gracefully (e.g., ignore a sequence older than the current checkpoint).

        If this method raises, the consumer will propagate the exception. The
        checkpoint is considered not persisted; the same records may be
        re-delivered on restart.
        """
        ...

    def get_all_checkpoints(self) -> Dict[str, str]:
        """Return all known checkpoints as {shard_id: sequence}.

        Used for monitoring and status reporting. May return stale data
        for eventually-consistent backends.
        """
        ...

    async def close(self) -> None:
        """Clean up resources and deallocate all owned shards.

        Implementations should flush any pending/buffered checkpoints
        before deallocating.
        """
        ...


Expand Down
145 changes: 123 additions & 22 deletions kinesis/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(
describe_timeout: int = 60,
idle_timeout: float = 2.0,
timestamp: Optional[datetime] = None,
checkpoint_interval: Optional[float] = None,
) -> None:

super(Consumer, self).__init__(
Expand Down Expand Up @@ -100,6 +101,23 @@ def __init__(

self._ready = asyncio.Event()

# Deferred checkpoints awaiting implicit ack (per-shard)
self._deferred_checkpoints: Dict[str, str] = {} # shard_id → sequence

# Checkpoint interval state (Change 2: checkpoint debouncing)
self.checkpoint_interval = checkpoint_interval
self._pending_checkpoints: Dict[str, str] = {} # shard_id → sequence
self._checkpoint_flusher_task: Optional[asyncio.Task] = None

# Validate mutual exclusion: checkpoint_interval + auto_checkpoint=False
if checkpoint_interval is not None and checkpointer is not None:
if getattr(checkpointer, "auto_checkpoint", True) is False:
raise ValueError(
"checkpoint_interval and auto_checkpoint=False are mutually exclusive. "
"Consumer-level debounce calls checkpointer.checkpoint() which silently "
"buffers under auto_checkpoint=False, achieving nothing."
)

# Shard management
self._last_shard_refresh = 0
self._shard_refresh_interval = 60 # Refresh shards every 60 seconds
Expand All @@ -124,6 +142,27 @@ async def close(self):
self.fetch_task.cancel()
self.fetch_task = None

# Flush deferred checkpoints (from __anext__ deferral)
if self._deferred_checkpoints and self.checkpointer:
for shard_id, sequence in self._deferred_checkpoints.items():
await self._maybe_checkpoint(shard_id, sequence)
self._deferred_checkpoints.clear()

# Cancel background flusher (triggers final flush via CancelledError)
if self._checkpoint_flusher_task is not None:
self._checkpoint_flusher_task.cancel()
try:
await self._checkpoint_flusher_task
except asyncio.CancelledError:
pass
self._checkpoint_flusher_task = None

# Final flush of any remaining interval-buffered checkpoints.
# Needed when the flusher task was cancelled before it started
# executing (CancelledError handler in _checkpoint_flusher never ran).
if self.checkpointer and self._pending_checkpoints:
await self._flush_pending_checkpoints()

if self.checkpointer:
await self.checkpointer.close()
if self.client is not None:
Expand Down Expand Up @@ -242,13 +281,18 @@ async def fetch(self):
log.debug("Shard {} got {} records".format(shard["ShardId"], len(records)))

total_items = 0
last_enqueued_sequence = None
for row in result["Records"]:
enqueued_any = False
for n, output in enumerate(self.processor.parse(row["Data"])):
try:
await asyncio.wait_for(self.queue.put(output), timeout=30.0)
enqueued_any = True
except asyncio.TimeoutError:
log.warning("Queue put timed out, skipping record")
continue
if enqueued_any:
last_enqueued_sequence = row["SequenceNumber"]
total_items += n + 1

# Get approx minutes behind..
Expand All @@ -274,25 +318,25 @@ async def fetch(self):
)
)

# Add checkpoint record
last_record = result["Records"][-1]
try:
await asyncio.wait_for(
self.queue.put(
{
"__CHECKPOINT__": {
"ShardId": shard["ShardId"],
"SequenceNumber": last_record["SequenceNumber"],
# Add checkpoint record — only for last *enqueued* sequence
if last_enqueued_sequence:
try:
await asyncio.wait_for(
self.queue.put(
{
"__CHECKPOINT__": {
"ShardId": shard["ShardId"],
"SequenceNumber": last_enqueued_sequence,
}
}
}
),
timeout=30.0,
)
except asyncio.TimeoutError:
log.warning("Checkpoint queue put timed out")
# Continue without checkpoint - not critical
),
timeout=30.0,
)
except asyncio.TimeoutError:
log.warning("Checkpoint queue put timed out")
# Continue without checkpoint - not critical

shard["LastSequenceNumber"] = last_record["SequenceNumber"]
shard["LastSequenceNumber"] = last_enqueued_sequence

else:
log.debug(
Expand All @@ -313,8 +357,20 @@ async def fetch(self):
if children:
log.info(f"Parent shard {shard_id} exhausted, enabling child shards: {children}")

# Deallocate the shard so other consumers can take over child shards
# Flush pending checkpoints for this shard before deallocating
if self.checkpointer:
# Flush deferred checkpoint if it's for this shard
if shard_id in self._deferred_checkpoints:
await self._maybe_checkpoint(
shard_id, self._deferred_checkpoints.pop(shard_id)
)

# Flush interval-buffered checkpoint for this shard
if shard_id in self._pending_checkpoints:
await self.checkpointer.checkpoint(
shard_id, self._pending_checkpoints.pop(shard_id)
)

await self.checkpointer.deallocate(shard_id)

# Remove shard iterator to stop fetching from this shard
Expand Down Expand Up @@ -767,6 +823,44 @@ def is_ready(self) -> bool:
"""Non-blocking check whether consumer has obtained shard iterators."""
return self._ready.is_set()

async def _maybe_checkpoint(self, shard_id: str, sequence: str):
"""Commit a checkpoint, either immediately or via interval buffer."""
if self.checkpoint_interval is None:
# No debouncing — checkpoint immediately
await self.checkpointer.checkpoint(shard_id, sequence)
return

# Buffer for background flush
self._pending_checkpoints[shard_id] = sequence

# Start flusher on first use
if self._checkpoint_flusher_task is None:
self._checkpoint_flusher_task = asyncio.ensure_future(self._checkpoint_flusher())

async def _checkpoint_flusher(self):
"""Background task that flushes buffered checkpoints periodically."""
try:
while True:
await asyncio.sleep(self.checkpoint_interval)
try:
await self._flush_pending_checkpoints()
except Exception:
log.exception("Error flushing checkpoints, will retry next interval")
except asyncio.CancelledError:
# Final flush on shutdown — propagate errors so close() knows
await self._flush_pending_checkpoints()

async def _flush_pending_checkpoints(self):
"""Flush all interval-buffered checkpoints to the backend.

Entries are removed individually on success so that a failure
mid-loop preserves the remaining (unflushed) checkpoints for
the next attempt.
"""
for shard_id in list(self._pending_checkpoints):
await self.checkpointer.checkpoint(shard_id, self._pending_checkpoints[shard_id])
del self._pending_checkpoints[shard_id]

async def __anext__(self):

if not self.shards:
Expand All @@ -782,9 +876,16 @@ async def __anext__(self):
if exception:
raise exception

# 1. Commit deferred checkpoints from previous iteration (implicit ACK)
if self._deferred_checkpoints:
for shard_id, sequence in self._deferred_checkpoints.items():
await self._maybe_checkpoint(shard_id, sequence)
self._deferred_checkpoints.clear()

checkpoint_count = 0
max_checkpoints = 100 # Prevent infinite checkpoint processing

# 2. Get items from queue, deferring any checkpoint sentinels
while True:
try:
item = await asyncio.wait_for(self.queue.get(), timeout=self.idle_timeout)
Expand All @@ -794,10 +895,10 @@ async def __anext__(self):

if item and isinstance(item, dict) and "__CHECKPOINT__" in item:
if self.checkpointer:
await self.checkpointer.checkpoint(
item["__CHECKPOINT__"]["ShardId"],
item["__CHECKPOINT__"]["SequenceNumber"],
)
# Don't execute now — defer to next __anext__ call
self._deferred_checkpoints[item["__CHECKPOINT__"]["ShardId"]] = item[
"__CHECKPOINT__"
]["SequenceNumber"]
checkpoint_count += 1
if checkpoint_count >= max_checkpoints:
log.warning(f"Processed {max_checkpoints} checkpoints, stopping iteration")
Expand Down
52 changes: 52 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,55 @@ def _random_string(length):
)

skip_if_no_redis = pytest.mark.skipif(not os.environ.get("REDIS_HOST"), reason="Redis not available (set REDIS_HOST)")


# --- Mock consumer for unit tests (no Docker) ---


def _make_mock_consumer(**kwargs):
    """Create a Consumer with mocked internals for unit testing (no Docker).

    Keyword arguments override the defaults below and are passed straight
    through to the Consumer constructor.
    """
    options = {
        "stream_name": "test-stream",
        "endpoint_url": "http://localhost:4567",
        "skip_describe_stream": True,
        "idle_timeout": 0.5,
        **kwargs,  # caller-supplied overrides win
    }

    consumer = Consumer(**options)
    consumer.stream_status = consumer.ACTIVE
    consumer.shards = [{"ShardId": "shard-0"}]

    # Stand in for the real fetch loop with a long-lived no-op task.
    consumer.fetch_task = asyncio.ensure_future(asyncio.sleep(3600))

    return consumer


@pytest_asyncio.fixture
async def mock_consumer():
    """Fixture that creates mock consumers and guarantees cleanup even on assertion failure.

    Yields a factory; every consumer the factory creates is torn down
    after the test, no matter how the test body exited.
    """
    consumers = []

    def _factory(**kwargs):
        c = _make_mock_consumer(**kwargs)
        consumers.append(c)
        return c

    async def _cancel_task(task):
        # Best-effort cancel + await; cleanup must never mask a test failure.
        if not task:
            return
        if not task.done():
            task.cancel()
        try:
            await task
        except (asyncio.CancelledError, Exception):
            # CancelledError is a BaseException on 3.8+, so it must be
            # listed explicitly alongside Exception.
            pass

    yield _factory

    # Teardown runs even if the test raised.
    for c in consumers:
        await _cancel_task(getattr(c, "_checkpoint_flusher_task", None))
        await _cancel_task(c.fetch_task)
Loading
Loading