12
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
13
# See the License for the specific language governing permissions and
14
14
# limitations under the License.
15
- import collections .abc
16
15
import logging
17
- from typing import TYPE_CHECKING , Collection , Iterable , Optional , Set , Tuple
16
+ from typing import TYPE_CHECKING , Dict , Iterable , Optional , Set , Tuple
17
+
18
+ from frozendict import frozendict
18
19
19
20
from synapse .api .constants import EventTypes , Membership
20
21
from synapse .api .errors import NotFoundError , UnsupportedRoomVersionError
29
30
from synapse .storage .databases .main .events_worker import EventsWorkerStore
30
31
from synapse .storage .databases .main .roommember import RoomMemberWorkerStore
31
32
from synapse .storage .state import StateFilter
32
- from synapse .types import JsonDict , StateMap
33
+ from synapse .types import JsonDict , JsonMapping , StateMap
33
34
from synapse .util .caches import intern_string
34
35
from synapse .util .caches .descriptors import cached , cachedList
35
36
@@ -132,7 +133,7 @@ def get_room_version_id_txn(self, txn: LoggingTransaction, room_id: str) -> str:
132
133
133
134
return room_version
134
135
135
- async def get_room_predecessor (self , room_id : str ) -> Optional [dict ]:
136
+ async def get_room_predecessor (self , room_id : str ) -> Optional [JsonMapping ]:
136
137
"""Get the predecessor of an upgraded room if it exists.
137
138
Otherwise return None.
138
139
@@ -158,9 +159,10 @@ async def get_room_predecessor(self, room_id: str) -> Optional[dict]:
158
159
predecessor = create_event .content .get ("predecessor" , None )
159
160
160
161
# Ensure the key is a dictionary
161
- if not isinstance (predecessor , collections . abc . Mapping ):
162
+ if not isinstance (predecessor , ( dict , frozendict ) ):
162
163
return None
163
164
165
+ # The keys must be strings since the data is JSON.
164
166
return predecessor
165
167
166
168
async def get_create_event_for_room (self , room_id : str ) -> EventBase :
@@ -306,7 +308,9 @@ async def _get_state_group_for_event(self, event_id: str) -> Optional[int]:
306
308
list_name = "event_ids" ,
307
309
num_args = 1 ,
308
310
)
309
- async def _get_state_group_for_events (self , event_ids : Collection [str ]) -> JsonDict :
311
+ async def _get_state_group_for_events (
312
+ self , event_ids : Iterable [str ]
313
+ ) -> Dict [str , int ]:
310
314
"""Returns mapping event_id -> state_group"""
311
315
rows = await self .db_pool .simple_select_many_batch (
312
316
table = "event_to_state_groups" ,
@@ -521,7 +525,7 @@ def _background_remove_left_rooms_txn(
521
525
)
522
526
523
527
for user_id in potentially_left_users - joined_users :
524
- await self .mark_remote_user_device_list_as_unsubscribed (user_id )
528
+ await self .mark_remote_user_device_list_as_unsubscribed (user_id ) # type: ignore[attr-defined]
525
529
526
530
return batch_size
527
531
0 commit comments