-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathcheckpointers.py
More file actions
312 lines (224 loc) · 10.1 KB
/
checkpointers.py
File metadata and controls
312 lines (224 loc) · 10.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
import asyncio
import json
import logging
import os
from datetime import datetime, timezone
from typing import Any, Dict, Optional, Protocol, Tuple, Union
log = logging.getLogger(__name__)
class CheckPointer(Protocol):
    """Protocol for checkpointer implementations.

    Checkpointers track processing progress per shard so that a consumer
    can resume from the correct position after a restart. They also provide
    shard-level locking so that multiple consumers don't process the same
    shard concurrently.
    """

    async def allocate(self, shard_id: str) -> Tuple[bool, Optional[str]]:
        """Allocate a shard for processing.

        Returns (True, sequence) if allocation succeeded. The sequence is the
        last checkpointed position (None if no prior checkpoint exists).
        Returns (False, None) if the shard is owned by another consumer.

        Implementations must be safe to call multiple times for the same shard
        (idempotent if already allocated by this consumer).
        """
        ...

    async def deallocate(self, shard_id: str) -> None:
        """Release a shard (e.g., on shard closure or consumer shutdown).

        Must preserve the last checkpoint sequence for future consumers.
        Called when a shard's iterator is exhausted (resharding) or on
        consumer close(). The consumer guarantees all pending checkpoints
        for this shard are flushed before deallocate() is called.
        """
        ...

    async def checkpoint(self, shard_id: str, sequence_number: str) -> None:
        """Record processing progress for a shard.

        NOTE(review): concrete implementations in this file name this
        parameter ``sequence``; a keyword call using the protocol's
        ``sequence_number`` name would break against them — confirm/align.

        Called after the consumer has yielded all records up to sequence_number
        to the user and the user has returned control (called __anext__ again).
        At-least-once semantics: the last batch's records may be reprocessed
        on restart if close() was not called after the final iteration.

        Implementations should be idempotent and handle out-of-order calls
        gracefully (e.g., ignore a sequence older than the current checkpoint).

        If this method raises, the consumer will propagate the exception. The
        checkpoint is considered not persisted; the same records may be
        re-delivered on restart.
        """
        ...

    def get_all_checkpoints(self) -> Dict[str, str]:
        """Return all known checkpoints as {shard_id: sequence}.

        Used for monitoring and status reporting. May return stale data
        for eventually-consistent backends.
        """
        ...

    async def close(self) -> None:
        """Clean up resources and deallocate all owned shards.

        Implementations should flush any pending/buffered checkpoints
        before deallocating.
        """
        ...
class BaseCheckPointer:
    """Common identity and per-shard state shared by all checkpointers.

    Tracks implementation-specific checkpoint state per shard in
    ``self._items`` and identifies this consumer by a ``name``/``id`` pair.
    Subclasses must provide ``allocate``/``deallocate``/``checkpoint``.
    """

    def __init__(self, name: str = "", id: Optional[Union[str, int]] = None) -> None:
        # Compare against None (not truthiness) so an explicit id of 0 or ""
        # is honored instead of silently falling back to the process id.
        self._id: Union[str, int] = os.getpid() if id is None else id
        self._name: str = name
        # shard_id -> implementation-specific checkpoint state
        self._items: Dict[str, Any] = {}

    def get_id(self) -> Union[str, int]:
        """Return this consumer's identifier."""
        return self._id

    def get_ref(self) -> str:
        """Return the "name/id" reference used in logs and lock values."""
        return "{}/{}".format(self._name, self._id)

    def get_all_checkpoints(self) -> Dict[str, Any]:
        """Return a shallow copy of all per-shard checkpoint state."""
        return self._items.copy()

    def get_checkpoint(self, shard_id: str) -> Any:
        """Return the checkpoint state for *shard_id* (None if unknown)."""
        return self._items.get(shard_id)

    async def close(self) -> None:
        """Deallocate every tracked shard concurrently."""
        log.info("{} stopping..".format(self.get_ref()))
        # Snapshot the keys first: deallocate() implementations may mutate
        # self._items (e.g. pop entries) while the gathered coroutines run.
        await asyncio.gather(*[self.deallocate(shard_id) for shard_id in list(self._items)])

    def is_allocated(self, shard_id: str) -> bool:
        """Return True if *shard_id* is currently tracked by this consumer."""
        return shard_id in self._items
class BaseHeartbeatCheckPointer(BaseCheckPointer):
    """Checkpointer base that periodically re-asserts shard ownership.

    Spawns a background task that, every ``heartbeat_frequency`` seconds,
    rewrites each owned shard's lock entry via ``do_heartbeat`` so other
    consumers can tell the owner is still alive within ``session_timeout``.
    Subclasses must implement ``do_heartbeat``, ``get_key`` and ``get_ts``.
    """

    def __init__(
        self,
        name,
        id=None,
        session_timeout=60,
        heartbeat_frequency=15,
        auto_checkpoint=True,
    ):
        super().__init__(name=name, id=id)
        # Seconds after which other consumers may steal an unrefreshed lock.
        self.session_timeout = session_timeout
        self.heartbeat_frequency = heartbeat_frequency
        # When False, checkpoints are buffered until manual_checkpoint().
        self.auto_checkpoint = auto_checkpoint
        self._manual_checkpoints = {}
        # ensure_future is the supported way to schedule a coroutine;
        # constructing asyncio.Task(...) directly is discouraged by the docs.
        self.heartbeat_task = asyncio.ensure_future(self.heartbeat())

    async def close(self):
        """Stop the heartbeat loop, then deallocate all owned shards."""
        log.debug("Cancelling heartbeat task..")
        self.heartbeat_task.cancel()
        await super().close()

    async def heartbeat(self):
        """Periodically refresh the lock entry for every owned shard."""
        while True:
            await asyncio.sleep(self.heartbeat_frequency)
            # todo: don't heartbeat if checkpoint already updated it recently
            # Snapshot the items: allocate()/deallocate() may mutate
            # self._items while we await do_heartbeat(), and mutating a
            # dict during iteration raises RuntimeError.
            for shard_id, sequence in list(self._items.items()):
                key = self.get_key(shard_id)
                val = {"ref": self.get_ref(), "ts": self.get_ts(), "sequence": sequence}
                log.debug("Heartbeating {}@{}".format(shard_id, sequence))
                await self.do_heartbeat(key, val)
class MemoryCheckPointer(BaseCheckPointer):
    """In-memory checkpointer for single-process use and testing.

    Each entry in ``self._items`` has the shape
    ``{"sequence": <last checkpoint or None>, "active": <bool>}``; a
    deallocated shard keeps its sequence so it can be re-allocated later
    and resume from the same position.
    """

    async def deallocate(self, shard_id):
        """Mark *shard_id* as no longer owned, preserving its sequence."""
        entry = self._items[shard_id]
        log.info("{} deallocated on {}@{}".format(self.get_ref(), shard_id, entry))
        entry["active"] = False

    def is_allocated(self, shard_id):
        """True only when the shard exists and is currently active."""
        entry = self._items.get(shard_id)
        return bool(entry and entry["active"])

    async def allocate(self, shard_id):
        """Claim the shard; return (ok, last_sequence)."""
        if self.is_allocated(shard_id):
            return False, None
        # Re-use a previously deallocated entry so its sequence survives.
        entry = self._items.setdefault(shard_id, {"sequence": None})
        entry["active"] = True
        return True, entry["sequence"]

    async def checkpoint(self, shard_id, sequence):
        """Record *sequence* as the latest processed position for the shard."""
        log.debug("{} checkpointed on {} @ {}".format(self.get_ref(), shard_id, sequence))
        self._items[shard_id]["sequence"] = sequence
class RedisCheckPointer(BaseHeartbeatCheckPointer):
    """Checkpointer backed by Redis (optionally a Redis Cluster).

    Shard ownership is a Redis key (see ``get_key``) holding a JSON blob
    ``{"ref": owner, "ts": unix_ts, "sequence": last_sequence}``. Ownership
    is claimed with SET NX, kept fresh by the inherited heartbeat, and may
    be stolen once the owner's timestamp is ``session_timeout`` seconds old.

    Connection settings come from the REDIS_HOST / REDIS_PORT /
    REDIS_PASSWORD / REDIS_DB environment variables.
    """

    def __init__(
        self,
        name,
        id=None,
        session_timeout=60,
        heartbeat_frequency=15,
        is_cluster=False,
        auto_checkpoint=True,
    ):
        super().__init__(
            name=name,
            id=id,
            session_timeout=session_timeout,
            heartbeat_frequency=heartbeat_frequency,
            auto_checkpoint=auto_checkpoint,
        )

        # Imported lazily so the redis dependency is only required when a
        # RedisCheckPointer is actually constructed.
        if is_cluster:
            from redis.asyncio.cluster import RedisCluster as Redis
        else:
            from redis.asyncio import Redis

        params = {
            "host": os.environ.get("REDIS_HOST", "localhost"),
            "port": int(os.environ.get("REDIS_PORT", "6379")),
            "password": os.environ.get("REDIS_PASSWORD"),
        }

        if not is_cluster:
            db = int(os.environ.get("REDIS_DB", 0))
            if db > 0:
                params["db"] = db
        else:
            # NOTE(review): only older redis cluster clients accept this
            # kwarg; recent redis-py removed it — confirm library version.
            params["skip_full_coverage_check"] = True

        self.client = Redis(**params)

    async def do_heartbeat(self, key, value):
        """Rewrite the lock key to refresh this consumer's ownership."""
        await self.client.set(key, json.dumps(value))

    def get_key(self, shard_id):
        """Redis key name for a shard's lock/checkpoint entry."""
        return "pyredis-{}-{}".format(self._name, shard_id)

    def get_ts(self):
        """Current UTC unix timestamp in whole seconds."""
        # int() already truncates to whole seconds; the previous
        # round(int(...)) double-wrap was redundant.
        return int(datetime.now(tz=timezone.utc).timestamp())

    async def checkpoint(self, shard_id, sequence):
        """Record progress; buffered locally when auto_checkpoint is off."""
        if not self.auto_checkpoint:
            log.debug("{} updated manual checkpoint {}@{}".format(self.get_ref(), shard_id, sequence))
            self._manual_checkpoints[shard_id] = sequence
            return
        await self._checkpoint(shard_id, sequence)

    async def manual_checkpoint(self):
        """Flush all buffered manual checkpoints to Redis."""
        # Swap the buffer out first so checkpoints arriving while we await
        # below land in the fresh dict for the next flush.
        items = list(self._manual_checkpoints.items())
        self._manual_checkpoints = {}
        for shard_id, sequence in items:
            await self._checkpoint(shard_id, sequence)

    async def _checkpoint(self, shard_id, sequence):
        """Write the new sequence, verifying we still own the lock."""
        key = self.get_key(shard_id)
        val = {"ref": self.get_ref(), "ts": self.get_ts(), "sequence": sequence}
        previous_val = await self.client.getset(key, json.dumps(val))
        previous_val = json.loads(previous_val) if previous_val else None
        # NOTE(review): NotImplementedError is an odd type for a runtime
        # ownership failure, but callers may catch it — left unchanged.
        if not previous_val:
            raise NotImplementedError("{} checkpointed on {} but key did not exist?".format(self.get_ref(), shard_id))
        if previous_val["ref"] != self.get_ref():
            # Bug fix: report the conflicting owner's ref (previous_val),
            # not val["ref"], which we just wrote and is always our own ref.
            raise NotImplementedError(
                "{} checkpointed on {} but ref is different {}".format(self.get_ref(), shard_id, previous_val["ref"])
            )
        log.debug("{} checkpointed on {}@{}".format(self.get_ref(), shard_id, sequence))
        self._items[shard_id] = sequence

    async def deallocate(self, shard_id):
        """Release the shard's lock while preserving its last sequence."""
        key = self.get_key(shard_id)
        # ref/ts of None marks the shard unowned but keeps the sequence so
        # the next consumer resumes from the correct position.
        val = {"ref": None, "ts": None, "sequence": self._items[shard_id]}
        await self.client.set(key, json.dumps(val))
        log.info("{} deallocated on {}@{}".format(self.get_ref(), shard_id, self._items[shard_id]))
        self._items.pop(shard_id)

    async def allocate(self, shard_id):
        """Try to claim *shard_id*; return (ok, last_sequence).

        Steals the lock when the current owner's heartbeat timestamp is more
        than ``session_timeout`` seconds old.
        """
        key = self.get_key(shard_id)
        ts = self.get_ts()

        # try to set lock
        success = await self.client.set(
            key,
            json.dumps({"ref": self.get_ref(), "ts": ts, "sequence": None}),
            nx=True,
        )

        val = await self.client.get(key)
        val = json.loads(val) if val else None

        original_ts = val["ts"]

        if success:
            log.info("{} allocated {} (new checkpoint)".format(self.get_ref(), shard_id))
            self._items[shard_id] = None
            return True, None

        if val["ts"]:
            # Key exists with a live-looking timestamp: back off, then check age.
            log.info("{} could not allocate {}, still in use by {}".format(self.get_ref(), shard_id, val["ref"]))
            # Wait a bit before carrying on to avoid spamming ourselves
            await asyncio.sleep(1)

            age = ts - original_ts

            # still alive?
            if age < self.session_timeout:
                return False, None

            log.info(
                "Attempting to take lock as {} is {} seconds over due..".format(val["ref"], age - self.session_timeout)
            )

            val["ref"] = self.get_ref()
            val["ts"] = ts

            # GETSET is the compare step of the steal: if someone changed
            # the ts between our GET and now, they won the race.
            previous_val = await self.client.getset(key, json.dumps(val))
            previous_val = json.loads(previous_val) if previous_val else None

            if previous_val["ts"] != original_ts:
                log.info("{} beat me to the lock..".format(previous_val["ref"]))
                return False, None

        # Reached when the key was deallocated (ts None) or we stole the lock.
        # NOTE(review): in the deallocated case the key is NOT rewritten with
        # our ref here (the heartbeat does it later), so two consumers could
        # briefly claim the same shard — confirm whether this race matters.
        log.info("{} allocating {}@{}".format(self.get_ref(), shard_id, val["sequence"]))
        self._items[shard_id] = val["sequence"]
        return True, val["sequence"]