
Commit db1cfe9

Update all stream IDs after processing replication rows (#14723)
This creates a new store method, `process_replication_position`, that is called after `process_replication_rows`. Moving stream ID advances into this method guarantees that any relevant cache invalidations have been applied before the stream is advanced. This avoids race conditions where Python switches between threads midway through `process_replication_rows` and stream IDs are advanced before caches are invalidated, due to class resolution ordering. See this comment/issue for further discussion: #14158 (comment)
1 parent c445611 commit db1cfe9
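
To make the ordering concrete, here is a toy sketch of the guarantee the two-phase split provides (the `ToyStore` class and its attributes are illustrative stand-ins, not Synapse's real store hierarchy or ID generator API): caches are invalidated in phase one, and the advertised stream position only advances in phase two, so any thread that observes the new position cannot read stale cached data.

# Toy sketch of the ordering this commit enforces; class and attribute
# names here are illustrative, not Synapse's real ones.

class ToyStore:
    def __init__(self) -> None:
        self.cache: dict = {}
        self.position = 0

    def process_replication_rows(self, token: int, rows: list) -> None:
        # Phase 1: drop any cache entries the replicated rows made stale.
        for key in rows:
            self.cache.pop(key, None)

    def process_replication_position(self, token: int) -> None:
        # Phase 2: only now advance the advertised stream position.
        self.position = token


store = ToyStore()
store.cache["alice"] = "stale value"

# The replication client calls the phases strictly in this order, so a
# thread that sees position == 7 can no longer read the stale "alice" entry.
store.process_replication_rows(token=7, rows=["alice"])
store.process_replication_position(token=7)
assert store.position == 7 and "alice" not in store.cache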

File tree

13 files changed: +95 -20 lines


changelog.d/14723.bugfix
Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+Ensure stream IDs are always updated after caches get invalidated with workers. Contributed by Nick @ Beeper (@fizzadar).

synapse/replication/tcp/client.py
Lines changed: 3 additions & 0 deletions

@@ -152,6 +152,9 @@ async def on_rdata(
             rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
         """
         self.store.process_replication_rows(stream_name, instance_name, token, rows)
+        # NOTE: this must be called after process_replication_rows to ensure any
+        # cache invalidations are first handled before any stream ID advances.
+        self.store.process_replication_position(stream_name, instance_name, token)
 
         if self.send_handler:
             await self.send_handler.process_replication_rows(stream_name, token, rows)

synapse/storage/_base.py
Lines changed: 16 additions & 1 deletion

@@ -57,7 +57,22 @@ def process_replication_rows(  # noqa: B027 (no-op by design)
         token: int,
         rows: Iterable[Any],
     ) -> None:
-        pass
+        """
+        Used by storage classes to invalidate caches based on incoming replication data. These
+        must not update any ID generators, use `process_replication_position`.
+        """
+
+    def process_replication_position(  # noqa: B027 (no-op by design)
+        self,
+        stream_name: str,
+        instance_name: str,
+        token: int,
+    ) -> None:
+        """
+        Used by storage classes to advance ID generators based on incoming replication data. This
+        is called after process_replication_rows such that caches are invalidated before any token
+        positions advance.
+        """
 
     def _invalidate_state_caches(
         self, room_id: str, members_changed: Collection[str]
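
As a sketch of how a store mixin is expected to use these two hooks (all names below — `StubBaseStore`, `ExampleWorkerStore`, `example_stream` — are hypothetical stand-ins for the real classes shown in the diffs that follow), each override handles only its own stream and always chains to `super()` so every class in the MRO runs:

from typing import Any, Dict, Iterable

class StubBaseStore:
    """Stand-in for SQLBaseStore: both hooks are deliberate no-ops."""

    def process_replication_rows(
        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
    ) -> None:
        pass

    def process_replication_position(
        self, stream_name: str, instance_name: str, token: int
    ) -> None:
        pass

class ExampleWorkerStore(StubBaseStore):
    def __init__(self) -> None:
        self._example_cache: Dict[Any, Any] = {}
        self._positions: Dict[str, int] = {}

    def process_replication_rows(
        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
    ) -> None:
        # Phase 1: cache invalidation only -- no ID generator may advance here.
        if stream_name == "example_stream":
            for row in rows:
                self._example_cache.pop(row, None)
        super().process_replication_rows(stream_name, instance_name, token, rows)

    def process_replication_position(
        self, stream_name: str, instance_name: str, token: int
    ) -> None:
        # Phase 2: advance the stream position, after all invalidations ran.
        if stream_name == "example_stream":
            self._positions[instance_name] = token
        super().process_replication_position(stream_name, instance_name, token)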

synapse/storage/databases/main/account_data.py
Lines changed: 10 additions & 4 deletions

@@ -436,10 +436,7 @@ def process_replication_rows(
         token: int,
         rows: Iterable[Any],
     ) -> None:
-        if stream_name == TagAccountDataStream.NAME:
-            self._account_data_id_gen.advance(instance_name, token)
-        elif stream_name == AccountDataStream.NAME:
+        if stream_name == AccountDataStream.NAME:
             for row in rows:
                 if not row.room_id:
                     self.get_global_account_data_by_type_for_user.invalidate(
@@ -454,6 +451,15 @@ def process_replication_rows(
 
         super().process_replication_rows(stream_name, instance_name, token, rows)
 
+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == TagAccountDataStream.NAME:
+            self._account_data_id_gen.advance(instance_name, token)
+        elif stream_name == AccountDataStream.NAME:
+            self._account_data_id_gen.advance(instance_name, token)
+        super().process_replication_position(stream_name, instance_name, token)
+
     async def add_account_data_to_room(
         self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
     ) -> int:

synapse/storage/databases/main/cache.py
Lines changed: 8 additions & 3 deletions

@@ -164,9 +164,6 @@ def process_replication_rows(
                 backfilled=True,
             )
         elif stream_name == CachesStream.NAME:
-            if self._cache_id_gen:
-                self._cache_id_gen.advance(instance_name, token)
-
             for row in rows:
                 if row.cache_func == CURRENT_STATE_CACHE_NAME:
                     if row.keys is None:
@@ -182,6 +179,14 @@ def process_replication_rows(
 
         super().process_replication_rows(stream_name, instance_name, token, rows)
 
+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == CachesStream.NAME:
+            if self._cache_id_gen:
+                self._cache_id_gen.advance(instance_name, token)
+        super().process_replication_position(stream_name, instance_name, token)
+
     def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
         data = row.data

synapse/storage/databases/main/deviceinbox.py
Lines changed: 7 additions & 0 deletions

@@ -157,6 +157,13 @@ def process_replication_rows(
             )
         return super().process_replication_rows(stream_name, instance_name, token, rows)
 
+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == ToDeviceStream.NAME:
+            self._device_inbox_id_gen.advance(instance_name, token)
+        super().process_replication_position(stream_name, instance_name, token)
+
     def get_to_device_stream_token(self) -> int:
         return self._device_inbox_id_gen.get_current_token()

synapse/storage/databases/main/devices.py
Lines changed: 9 additions & 2 deletions

@@ -162,14 +162,21 @@ def process_replication_rows(
         self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
     ) -> None:
         if stream_name == DeviceListsStream.NAME:
-            self._device_list_id_gen.advance(instance_name, token)
             self._invalidate_caches_for_devices(token, rows)
         elif stream_name == UserSignatureStream.NAME:
-            self._device_list_id_gen.advance(instance_name, token)
             for row in rows:
                 self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
         return super().process_replication_rows(stream_name, instance_name, token, rows)
 
+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == DeviceListsStream.NAME:
+            self._device_list_id_gen.advance(instance_name, token)
+        elif stream_name == UserSignatureStream.NAME:
+            self._device_list_id_gen.advance(instance_name, token)
+        super().process_replication_position(stream_name, instance_name, token)
+
     def _invalidate_caches_for_devices(
         self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]
     ) -> None:

synapse/storage/databases/main/events_worker.py
Lines changed: 10 additions & 5 deletions

@@ -388,11 +388,7 @@ def process_replication_rows(
         token: int,
         rows: Iterable[Any],
     ) -> None:
-        if stream_name == EventsStream.NAME:
-            self._stream_id_gen.advance(instance_name, token)
-        elif stream_name == BackfillStream.NAME:
-            self._backfill_id_gen.advance(instance_name, -token)
-        elif stream_name == UnPartialStatedEventStream.NAME:
+        if stream_name == UnPartialStatedEventStream.NAME:
             for row in rows:
                 assert isinstance(row, UnPartialStatedEventStreamRow)
 
@@ -405,6 +401,15 @@ def process_replication_rows(
 
         super().process_replication_rows(stream_name, instance_name, token, rows)
 
+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == EventsStream.NAME:
+            self._stream_id_gen.advance(instance_name, token)
+        elif stream_name == BackfillStream.NAME:
+            self._backfill_id_gen.advance(instance_name, -token)
+        super().process_replication_position(stream_name, instance_name, token)
+
     async def have_censored_event(self, event_id: str) -> bool:
         """Check if an event has been censored, i.e. if the content of the event has been erased
         from the database due to a redaction.

synapse/storage/databases/main/presence.py
Lines changed: 7 additions & 1 deletion

@@ -439,8 +439,14 @@ def process_replication_rows(
         rows: Iterable[Any],
     ) -> None:
         if stream_name == PresenceStream.NAME:
-            self._presence_id_gen.advance(instance_name, token)
             for row in rows:
                 self.presence_stream_cache.entity_has_changed(row.user_id, token)
                 self._get_presence_for_user.invalidate((row.user_id,))
         return super().process_replication_rows(stream_name, instance_name, token, rows)
+
+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == PresenceStream.NAME:
+            self._presence_id_gen.advance(instance_name, token)
+        super().process_replication_position(stream_name, instance_name, token)

synapse/storage/databases/main/push_rule.py
Lines changed: 7 additions & 0 deletions

@@ -154,6 +154,13 @@ def process_replication_rows(
             self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
         return super().process_replication_rows(stream_name, instance_name, token, rows)
 
+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == PushRulesStream.NAME:
+            self._push_rules_stream_id_gen.advance(instance_name, token)
+        super().process_replication_position(stream_name, instance_name, token)
+
     @cached(max_entries=5000)
     async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules:
         rows = await self.db_pool.simple_select_list(
