50
50
Dict ,
51
51
Iterable ,
52
52
List ,
53
+ Mapping ,
53
54
Optional ,
54
55
Protocol ,
55
56
Set ,
80
81
from synapse .storage .engines import BaseDatabaseEngine , PostgresEngine , Sqlite3Engine
81
82
from synapse .storage .util .id_generators import MultiWriterIdGenerator
82
83
from synapse .types import PersistedEventPosition , RoomStreamToken , StrCollection
83
- from synapse .util .caches .descriptors import cached
84
+ from synapse .util .caches .descriptors import cached , cachedList
84
85
from synapse .util .caches .stream_change_cache import StreamChangeCache
85
86
from synapse .util .cancellation import cancellable
86
87
from synapse .util .iterutils import batch_iter
@@ -1381,40 +1382,85 @@ async def bulk_get_last_event_pos_in_room_before_stream_ordering(
1381
1382
rooms
1382
1383
"""
1383
1384
1385
+ # First we just get the latest positions for the room, as the vast
1386
+ # majority of them will be before the given end token anyway. By doing
1387
+ # this we can cache most rooms.
1388
+ uncapped_results = await self ._bulk_get_max_event_pos (room_ids )
1389
+
1390
+ # Check that the stream positions for the rooms are from before the
1391
+ # minimum position of the token. If not, then we need to fetch more
1392
+ # rows.
1393
+ results : Dict [str , int ] = {}
1394
+ recheck_rooms : Set [str ] = set ()
1384
1395
min_token = end_token .stream
1385
- max_token = end_token .get_max_stream_pos ()
1396
+ for room_id , stream in uncapped_results .items ():
1397
+ if stream <= min_token :
1398
+ results [room_id ] = stream
1399
+ else :
1400
+ recheck_rooms .add (room_id )
1401
+
1402
+ if not recheck_rooms :
1403
+ return results
1404
+
1405
+ # There shouldn't be many rooms that we need to recheck, so we do them
1406
+ # one-by-one.
1407
+ for room_id in recheck_rooms :
1408
+ result = await self .get_last_event_pos_in_room_before_stream_ordering (
1409
+ room_id , end_token
1410
+ )
1411
+ if result is not None :
1412
+ results [room_id ] = result [1 ].stream
1413
+
1414
+ return results
1415
+
1416
+ @cached ()
1417
+ async def _get_max_event_pos (self , room_id : str ) -> int :
1418
+ raise NotImplementedError ()  # Stub: always raises. `_bulk_get_max_event_pos` names this method as its `cached_method_name`, so values are presumably supplied via that bulk fetch -- TODO confirm against `cachedList`.
1419
+
1420
+ @cachedList (cached_method_name = "_get_max_event_pos" , list_name = "room_ids" )
1421
+ async def _bulk_get_max_event_pos (
1422
+ self , room_ids : StrCollection
1423
+ ) -> Mapping [str , int ]:
1424
+ """Fetch the max position of a persisted event in the room."""
1425
+
1426
+ # We need to be careful not to return positions ahead of the current
1427
+ # positions, so we get the current token now and cap our queries to it.
1428
+ now_token = self .get_room_max_token ()
1429
+ max_pos = now_token .get_max_stream_pos ()
1430
+
1386
1431
results : Dict [str , int ] = {}
1387
1432
1388
1433
# First, we check for the rooms in the stream change cache to see if we
1389
1434
# can just use the latest position from it.
1390
1435
missing_room_ids : Set [str ] = set ()
1391
1436
for room_id in room_ids :
1392
1437
stream_pos = self ._events_stream_cache .get_max_pos_of_last_change (room_id )
1393
- if stream_pos and stream_pos <= min_token :
1438
+ if stream_pos is not None :
1394
1439
results [room_id ] = stream_pos
1395
1440
else :
1396
1441
missing_room_ids .add (room_id )
1397
1442
1443
+ if not missing_room_ids :
1444
+ return results
1445
+
1398
1446
# Next, we query the stream position from the DB. At first we fetch all
1399
1447
# positions up to the *max* stream pos of the current token, then filter
1400
1448
# them down. We do this as a) this is a cheaper query, and b) the vast
1401
1449
# majority of rooms will have a latest position from before the min stream
1402
1450
# pos.
1403
1451
1404
- def bulk_get_last_event_pos_txn (
1405
- txn : LoggingTransaction , batch_room_ids : StrCollection
1452
+ def bulk_get_max_event_pos_txn (
1453
+ txn : LoggingTransaction , batched_room_ids : StrCollection
1406
1454
) -> Dict [str , int ]:
1407
- # This query fetches the latest stream position in the rooms before
1408
- # the given max position.
1409
1455
clause , args = make_in_list_sql_clause (
1410
- self .database_engine , "room_id" , batch_room_ids
1456
+ self .database_engine , "room_id" , batched_room_ids
1411
1457
)
1412
1458
sql = f"""
1413
1459
SELECT room_id, (
1414
1460
SELECT stream_ordering FROM events AS e
1415
1461
LEFT JOIN rejections USING (event_id)
1416
1462
WHERE e.room_id = r.room_id
1417
- AND stream_ordering <= ?
1463
+ AND e. stream_ordering <= ?
1418
1464
AND NOT outlier
1419
1465
AND rejection_reason IS NULL
1420
1466
ORDER BY stream_ordering DESC
@@ -1423,72 +1469,29 @@ def bulk_get_last_event_pos_txn(
1423
1469
FROM rooms AS r
1424
1470
WHERE { clause }
1425
1471
"""
1426
- txn .execute (sql , [max_token ] + args )
1472
+ txn .execute (sql , [max_pos ] + args )
1427
1473
return {row [0 ]: row [1 ] for row in txn }
1428
1474
1429
1475
recheck_rooms : Set [str ] = set ()
1430
- for batched in batch_iter (missing_room_ids , 1000 ):
1431
- result = await self .db_pool .runInteraction (
1432
- "bulk_get_last_event_pos_in_room_before_stream_ordering" ,
1433
- bulk_get_last_event_pos_txn ,
1434
- batched ,
1476
+ for batched in batch_iter (missing_room_ids , 1000 ):  # only rooms the stream change cache could not answer above
1477
+ batch_results = await self .db_pool .runInteraction (
1478
+ "_bulk_get_max_event_pos" , bulk_get_max_event_pos_txn , batched
1435
1479
)
1436
-
1437
- # Check that the stream position for the rooms are from before the
1438
- # minimum position of the token. If not then we need to fetch more
1439
- # rows.
1440
- for room_id , stream in result .items ():
1441
- if stream <= min_token :
1442
- results [room_id ] = stream
1480
+ for room_id , stream_ordering in batch_results .items ():
1481
+ if stream_ordering <= now_token .stream :
1482
+ results [room_id ] = stream_ordering  # record this room only; rooms with a too-new position are left to the recheck below
1443
1483
else :
1444
1484
recheck_rooms .add (room_id )
1445
1485
1446
- if not recheck_rooms :
1447
- return results
1448
-
1449
- # For the remaining rooms we need to fetch all rows between the min and
1450
- # max stream positions in the end token, and filter out the rows that
1451
- # are after the end token.
1452
- #
1453
- # This query should be fast as the range between the min and max should
1454
- # be small.
1455
-
1456
- def bulk_get_last_event_pos_recheck_txn (
1457
- txn : LoggingTransaction , batch_room_ids : StrCollection
1458
- ) -> Dict [str , int ]:
1459
- clause , args = make_in_list_sql_clause (
1460
- self .database_engine , "room_id" , batch_room_ids
1461
- )
1462
- sql = f"""
1463
- SELECT room_id, instance_name, stream_ordering
1464
- FROM events
1465
- WHERE ? < stream_ordering AND stream_ordering <= ?
1466
- AND NOT outlier
1467
- AND rejection_reason IS NULL
1468
- AND { clause }
1469
- ORDER BY stream_ordering ASC
1470
- """
1471
- txn .execute (sql , [min_token , max_token ] + args )
1472
-
1473
- # We take the max stream ordering that is less than the token. Since
1474
- # we ordered by stream ordering we just need to iterate through and
1475
- # take the last matching stream ordering.
1476
- txn_results : Dict [str , int ] = {}
1477
- for row in txn :
1478
- room_id = row [0 ]
1479
- event_pos = PersistedEventPosition (row [1 ], row [2 ])
1480
- if not event_pos .persisted_after (end_token ):
1481
- txn_results [room_id ] = event_pos .stream
1482
-
1483
- return txn_results
1484
-
1485
- for batched in batch_iter (recheck_rooms , 1000 ):
1486
- recheck_result = await self .db_pool .runInteraction (
1487
- "bulk_get_last_event_pos_in_room_before_stream_ordering_recheck" ,
1488
- bulk_get_last_event_pos_recheck_txn ,
1489
- batched ,
1486
+ # We now need to handle rooms where the above query returned a stream
1487
+ # position that was potentially too new. This should happen very rarely
1488
+ # so we just query the rooms one-by-one
1489
+ for room_id in recheck_rooms :
1490
+ result = await self .get_last_event_pos_in_room_before_stream_ordering (
1491
+ room_id , now_token
1490
1492
)
1491
- results .update (recheck_result )
1493
+ if result is not None :
1494
+ results [room_id ] = result [1 ].stream
1492
1495
1493
1496
return results
1494
1497
0 commit comments