Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,7 @@ Options:
| create_stream_shards | 1 | Sets the number of shards for your new stream. Note: if the stream already exists, this setting is ignored |
| describe_timeout | 60 | Timeout in seconds for waiting for stream to become ACTIVE during startup. Increase for slow backends (e.g. LocalStack) |
| idle_timeout | 2.0 | Seconds to wait for new records before ending iteration. Controls how long `async for` blocks on an empty queue before raising `StopAsyncIteration` |
| checkpoint_interval | None | Seconds between checkpoint writes. `None` = checkpoint every batch. Set to e.g. `5.0` to reduce backend write pressure on active streams. Uses a background flusher task so checkpoints fire even during quiet periods. Mutually exclusive with `auto_checkpoint=False` on heartbeat checkpointers |
| timestamp | None | Timestamp to start reading stream from. Used with iterator type "AT_TIMESTAMP" |

#### Consumer Methods
Expand Down
53 changes: 47 additions & 6 deletions kinesis/checkpointers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,67 @@


class CheckPointer(Protocol):
    """Protocol for checkpointer implementations.

    A checkpointer records per-shard processing progress so a consumer can
    resume from the correct position after a restart. It also provides
    shard-level locking so that two consumers never process the same shard
    concurrently.
    """

    async def allocate(self, shard_id: str) -> Tuple[bool, Optional[str]]:
        """Try to claim a shard for this consumer.

        Returns ``(True, sequence)`` on success, where ``sequence`` is the
        last checkpointed position (``None`` when no prior checkpoint
        exists). Returns ``(False, None)`` when another consumer owns the
        shard.

        Must be idempotent: repeating the call for a shard this consumer
        already owns succeeds.
        """
        ...

    async def deallocate(self, shard_id: str) -> None:
        """Release ownership of a shard while preserving its checkpoint.

        Invoked when a shard's iterator is exhausted (resharding) or during
        consumer close(). The consumer guarantees every pending checkpoint
        for the shard has been flushed before this is called, and the last
        checkpoint sequence must remain available to future consumers.
        """
        ...

    async def checkpoint(self, shard_id: str, sequence_number: str) -> None:
        """Persist processing progress for a shard.

        Invoked once the consumer has yielded every record up to
        ``sequence_number`` to the user and control has returned (the next
        ``__anext__`` call). Semantics are at-least-once: the final batch
        may be reprocessed after a restart if close() was never called
        following the last iteration.

        Implementations should be idempotent and tolerate out-of-order
        calls (e.g. ignore a sequence older than the stored checkpoint).
        If this method raises, the consumer propagates the exception and
        the checkpoint counts as not persisted, so the same records may be
        re-delivered on restart.
        """
        ...

    def get_all_checkpoints(self) -> Dict[str, str]:
        """Return every known checkpoint as ``{shard_id: sequence}``.

        Intended for monitoring and status reporting; eventually-consistent
        backends may return stale values.
        """
        ...

    async def close(self) -> None:
        """Flush buffered checkpoints, deallocate owned shards, clean up."""
        ...


Expand Down
165 changes: 139 additions & 26 deletions kinesis/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(
describe_timeout: int = 60,
idle_timeout: float = 2.0,
timestamp: Optional[datetime] = None,
checkpoint_interval: Optional[float] = None,
) -> None:

super(Consumer, self).__init__(
Expand Down Expand Up @@ -100,6 +101,25 @@ def __init__(

self._ready = asyncio.Event()

# Deferred checkpoints awaiting implicit ack (per-shard)
self._deferred_checkpoints: Dict[str, str] = {} # shard_id → sequence

# Checkpoint interval state
if checkpoint_interval is not None and checkpoint_interval <= 0:
raise ValueError(f"checkpoint_interval must be positive, got {checkpoint_interval}")
self.checkpoint_interval = checkpoint_interval
self._pending_checkpoints: Dict[str, str] = {} # shard_id → sequence
self._checkpoint_flusher_task: Optional[asyncio.Task] = None

# Validate mutual exclusion: checkpoint_interval + auto_checkpoint=False
if checkpoint_interval is not None and checkpointer is not None:
if getattr(checkpointer, "auto_checkpoint", True) is False:
raise ValueError(
"checkpoint_interval and auto_checkpoint=False are mutually exclusive. "
"Consumer-level debounce calls checkpointer.checkpoint() which silently "
"buffers under auto_checkpoint=False, achieving nothing."
)

# Shard management
self._last_shard_refresh = 0
self._shard_refresh_interval = 60 # Refresh shards every 60 seconds
Expand All @@ -124,6 +144,27 @@ async def close(self):
self.fetch_task.cancel()
self.fetch_task = None

# Flush deferred checkpoints (from __anext__ deferral)
if self._deferred_checkpoints and self.checkpointer:
for shard_id, sequence in self._deferred_checkpoints.items():
await self._maybe_checkpoint(shard_id, sequence)
self._deferred_checkpoints.clear()

# Cancel background flusher (triggers final flush via CancelledError)
if self._checkpoint_flusher_task is not None:
self._checkpoint_flusher_task.cancel()
try:
await self._checkpoint_flusher_task
except asyncio.CancelledError:
pass
self._checkpoint_flusher_task = None

# Final flush of any remaining interval-buffered checkpoints.
# Needed when the flusher task was cancelled before it started
# executing (CancelledError handler in _checkpoint_flusher never ran).
if self.checkpointer and self._pending_checkpoints:
await self._flush_pending_checkpoints()

if self.checkpointer:
await self.checkpointer.close()
if self.client is not None:
Expand Down Expand Up @@ -242,14 +283,21 @@ async def fetch(self):
log.debug("Shard {} got {} records".format(shard["ShardId"], len(records)))

total_items = 0
last_enqueued_sequence = None
for row in result["Records"]:
for n, output in enumerate(self.processor.parse(row["Data"])):
item_count = 0
dropped = False
for output in self.processor.parse(row["Data"]):
try:
await asyncio.wait_for(self.queue.put(output), timeout=30.0)
item_count += 1
except asyncio.TimeoutError:
log.warning("Queue put timed out, skipping record")
continue
total_items += n + 1
dropped = True
break
if item_count > 0 and not dropped:
last_enqueued_sequence = row["SequenceNumber"]
total_items += item_count

# Get approx minutes behind..
last_arrival = records[-1].get("ApproximateArrivalTimestamp")
Expand All @@ -274,27 +322,30 @@ async def fetch(self):
)
)

# Add checkpoint record
last_record = result["Records"][-1]
try:
await asyncio.wait_for(
self.queue.put(
{
"__CHECKPOINT__": {
"ShardId": shard["ShardId"],
"SequenceNumber": last_record["SequenceNumber"],
}
}
),
timeout=30.0,
)
except asyncio.TimeoutError:
log.warning("Checkpoint queue put timed out")
# Continue without checkpoint - not critical
# Add checkpoint record — only for last *enqueued* sequence.
# Skip sentinel for exhausted shards; their terminal checkpoint
# is flushed synchronously in the deallocation block below.
if last_enqueued_sequence:
shard["LastSequenceNumber"] = last_enqueued_sequence

shard["LastSequenceNumber"] = last_record["SequenceNumber"]
if result.get("NextShardIterator"):
try:
await asyncio.wait_for(
self.queue.put(
{
"__CHECKPOINT__": {
"ShardId": shard["ShardId"],
"SequenceNumber": last_enqueued_sequence,
}
}
),
timeout=30.0,
)
except asyncio.TimeoutError:
log.warning("Checkpoint queue put timed out")

else:
last_enqueued_sequence = None
log.debug(
"Shard {} caught up, sleeping {}s".format(shard["ShardId"], self.sleep_time_no_records)
)
Expand All @@ -313,8 +364,25 @@ async def fetch(self):
if children:
log.info(f"Parent shard {shard_id} exhausted, enabling child shards: {children}")

# Deallocate the shard so other consumers can take over child shards
# Flush all pending checkpoints for this shard before deallocating
if self.checkpointer:
# Flush deferred checkpoint if it's for this shard
if shard_id in self._deferred_checkpoints:
await self._maybe_checkpoint(
shard_id, self._deferred_checkpoints.pop(shard_id)
)

# Flush interval-buffered checkpoint for this shard
if shard_id in self._pending_checkpoints:
await self.checkpointer.checkpoint(
shard_id, self._pending_checkpoints.pop(shard_id)
)

# Note: the terminal batch's records are enqueued but no
# checkpoint sentinel was added (would race with deallocate).
# Records will be processed normally; the batch replays on
# restart (at-least-once safe, no data loss).

await self.checkpointer.deallocate(shard_id)

# Remove shard iterator to stop fetching from this shard
Expand Down Expand Up @@ -767,6 +835,44 @@ def is_ready(self) -> bool:
"""Non-blocking check whether consumer has obtained shard iterators."""
return self._ready.is_set()

async def _maybe_checkpoint(self, shard_id: str, sequence: str):
"""Commit a checkpoint, either immediately or via interval buffer."""
if self.checkpoint_interval is None:
# No debouncing — checkpoint immediately
await self.checkpointer.checkpoint(shard_id, sequence)
return

# Buffer for background flush
self._pending_checkpoints[shard_id] = sequence

# Start flusher on first use
if self._checkpoint_flusher_task is None:
self._checkpoint_flusher_task = asyncio.ensure_future(self._checkpoint_flusher())

async def _checkpoint_flusher(self):
"""Background task that flushes buffered checkpoints periodically."""
try:
while True:
await asyncio.sleep(self.checkpoint_interval)
try:
await self._flush_pending_checkpoints()
except Exception:
log.exception("Error flushing checkpoints, will retry next interval")
except asyncio.CancelledError:
# Final flush on shutdown — propagate errors so close() knows
await self._flush_pending_checkpoints()

async def _flush_pending_checkpoints(self):
"""Flush all interval-buffered checkpoints to the backend.

Entries are removed individually on success so that a failure
mid-loop preserves the remaining (unflushed) checkpoints for
the next attempt.
"""
for shard_id in list(self._pending_checkpoints):
await self.checkpointer.checkpoint(shard_id, self._pending_checkpoints[shard_id])
del self._pending_checkpoints[shard_id]

async def __anext__(self):

if not self.shards:
Expand All @@ -782,9 +888,16 @@ async def __anext__(self):
if exception:
raise exception

# 1. Commit deferred checkpoints from previous iteration (implicit ACK)
if self._deferred_checkpoints:
for shard_id, sequence in self._deferred_checkpoints.items():
await self._maybe_checkpoint(shard_id, sequence)
self._deferred_checkpoints.clear()

checkpoint_count = 0
max_checkpoints = 100 # Prevent infinite checkpoint processing

# 2. Get items from queue, deferring any checkpoint sentinels
while True:
try:
item = await asyncio.wait_for(self.queue.get(), timeout=self.idle_timeout)
Expand All @@ -794,10 +907,10 @@ async def __anext__(self):

if item and isinstance(item, dict) and "__CHECKPOINT__" in item:
if self.checkpointer:
await self.checkpointer.checkpoint(
item["__CHECKPOINT__"]["ShardId"],
item["__CHECKPOINT__"]["SequenceNumber"],
)
# Don't execute now — defer to next __anext__ call
self._deferred_checkpoints[item["__CHECKPOINT__"]["ShardId"]] = item[
"__CHECKPOINT__"
]["SequenceNumber"]
checkpoint_count += 1
if checkpoint_count >= max_checkpoints:
log.warning(f"Processed {max_checkpoints} checkpoints, stopping iteration")
Expand Down
17 changes: 15 additions & 2 deletions kinesis/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ def __init__(
self._positions: Dict[int, int] = {}
self._exhausted: set = set()
self._buffer: deque = deque()
self._deferred_checkpoints: Dict[str, str] = {} # shard_id → sequence

@property
def stream(self) -> MemoryStream:
Expand Down Expand Up @@ -340,16 +341,26 @@ async def __aenter__(self) -> "MockConsumer":
return self

async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
    """Async context-manager exit: delegate to close() so any deferred
    checkpoints are flushed before the checkpointer shuts down."""
    await self.close()

def __aiter__(self) -> AsyncIterator[Any]:
    """Return self: the consumer is its own async iterator."""
    return self

async def _flush_deferred_checkpoints(self) -> None:
"""Commit deferred checkpoints (implicit ack from previous iteration)."""
for shard_id, seq in self._deferred_checkpoints.items():
await self.checkpointer.checkpoint(shard_id, seq)
self._deferred_checkpoints.clear()

async def __anext__(self) -> Any:
# Drain buffer from a previous multi-item parse (aggregated records)
if self._buffer:
return self._buffer.popleft()

# Commit deferred checkpoints from previous iteration (implicit ack)
if self._deferred_checkpoints:
await self._flush_deferred_checkpoints()

stream = self.stream

while True:
Expand All @@ -370,7 +381,7 @@ async def __anext__(self) -> Any:
continue

seq, data, partition_key = record
await self.checkpointer.checkpoint(shard.shard_id, seq)
self._deferred_checkpoints[shard.shard_id] = seq

items = list(self.processor.parse(data))
if items:
Expand All @@ -386,6 +397,8 @@ async def __anext__(self) -> Any:
await asyncio.sleep(self._poll_delay if self._poll_delay > 0 else 0)

async def close(self) -> None:
    """Flush any deferred checkpoints, then close the checkpointer."""
    pending = self._deferred_checkpoints
    if pending:
        # Commit the implicit-ack buffer before shutting the backend down.
        await self._flush_deferred_checkpoints()
    await self.checkpointer.close()


Expand Down
Loading
Loading