1515# limitations under the License.
1616
1717import logging
18- from typing import List , Tuple
18+ from typing import Dict , List , Tuple
1919
2020from canonicaljson import json
2121
22- from twisted .internet import defer
23-
2422from synapse .storage ._base import db_to_json
2523from synapse .storage .databases .main .account_data import AccountDataWorkerStore
24+ from synapse .types import JsonDict
2625from synapse .util .caches .descriptors import cached
2726
2827logger = logging .getLogger (__name__ )
2928
3029
3130class TagsWorkerStore (AccountDataWorkerStore ):
3231 @cached ()
33- def get_tags_for_user (self , user_id ) :
32+ async def get_tags_for_user (self , user_id : str ) -> Dict [ str , Dict [ str , JsonDict ]] :
3433 """Get all the tags for a user.
3534
3635
3736 Args:
38- user_id(str) : The user to get the tags for.
37+ user_id: The user to get the tags for.
3938 Returns:
40- A deferred dict mapping from room_id strings to dicts mapping from
41- tag strings to tag content.
39+ A mapping from room_id strings to dicts mapping from tag strings to
40+ tag content.
4241 """
4342
44- deferred = self .db_pool .simple_select_list (
43+ rows = await self .db_pool .simple_select_list (
4544 "room_tags" , {"user_id" : user_id }, ["room_id" , "tag" , "content" ]
4645 )
4746
48- @deferred .addCallback
49- def tags_by_room (rows ):
50- tags_by_room = {}
51- for row in rows :
52- room_tags = tags_by_room .setdefault (row ["room_id" ], {})
53- room_tags [row ["tag" ]] = db_to_json (row ["content" ])
54- return tags_by_room
55-
56- return deferred
47+ tags_by_room = {}
48+ for row in rows :
49+ room_tags = tags_by_room .setdefault (row ["room_id" ], {})
50+ room_tags [row ["tag" ]] = db_to_json (row ["content" ])
51+ return tags_by_room
5752
5853 async def get_all_updated_tags (
5954 self , instance_name : str , last_id : int , current_id : int , limit : int
@@ -127,17 +122,19 @@ def get_tag_content(txn, tag_ids):
127122
128123 return results , upto_token , limited
129124
130- @defer .inlineCallbacks
131- def get_updated_tags (self , user_id , stream_id ):
125+ async def get_updated_tags (
126+ self , user_id : str , stream_id : int
127+ ) -> Dict [str , List [str ]]:
132128 """Get all the tags for the rooms where the tags have changed since the
133129 given version
134130
135131 Args:
136132 user_id(str): The user to get the tags for.
137133 stream_id(int): The earliest update to get for the user.
134+
138135 Returns:
139- A deferred dict mapping from room_id strings to lists of tag
140- strings for all the rooms that changed since the stream_id token.
136+ A mapping from room_id strings to lists of tag strings for all the
137+ rooms that changed since the stream_id token.
141138 """
142139
143140 def get_updated_tags_txn (txn ):
@@ -155,47 +152,53 @@ def get_updated_tags_txn(txn):
155152 if not changed :
156153 return {}
157154
158- room_ids = yield self .db_pool .runInteraction (
155+ room_ids = await self .db_pool .runInteraction (
159156 "get_updated_tags" , get_updated_tags_txn
160157 )
161158
162159 results = {}
163160 if room_ids :
164- tags_by_room = yield self .get_tags_for_user (user_id )
161+ tags_by_room = await self .get_tags_for_user (user_id )
165162 for room_id in room_ids :
166163 results [room_id ] = tags_by_room .get (room_id , {})
167164
168165 return results
169166
170- def get_tags_for_room (self , user_id , room_id ):
167+ async def get_tags_for_room (
168+ self , user_id : str , room_id : str
169+ ) -> Dict [str , JsonDict ]:
171170 """Get all the tags for the given room
171+
172172 Args:
173- user_id(str): The user to get tags for
174- room_id(str): The room to get tags for
173+ user_id: The user to get tags for
174+ room_id: The room to get tags for
175+
175176 Returns:
176- A deferred list of string tags.
177+ A mapping of tags to tag content .
177178 """
178- return self .db_pool .simple_select_list (
179+ rows = await self .db_pool .simple_select_list (
179180 table = "room_tags" ,
180181 keyvalues = {"user_id" : user_id , "room_id" : room_id },
181182 retcols = ("tag" , "content" ),
182183 desc = "get_tags_for_room" ,
183- ).addCallback (
184- lambda rows : {row ["tag" ]: db_to_json (row ["content" ]) for row in rows }
185184 )
185+ return {row ["tag" ]: db_to_json (row ["content" ]) for row in rows }
186186
187187
188188class TagsStore (TagsWorkerStore ):
189- @defer .inlineCallbacks
190- def add_tag_to_room (self , user_id , room_id , tag , content ):
189+ async def add_tag_to_room (
190+ self , user_id : str , room_id : str , tag : str , content : JsonDict
191+ ) -> int :
191192 """Add a tag to a room for a user.
193+
192194 Args:
193- user_id(str): The user to add a tag for.
194- room_id(str): The room to add a tag for.
195- tag(str): The tag name to add.
196- content(dict): A json object to associate with the tag.
195+ user_id: The user to add a tag for.
196+ room_id: The room to add a tag for.
197+ tag: The tag name to add.
198+ content: A json object to associate with the tag.
199+
197200 Returns:
198- A deferred that completes once the tag has been added .
201+ The next account data ID .
199202 """
200203 content_json = json .dumps (content )
201204
@@ -209,18 +212,17 @@ def add_tag_txn(txn, next_id):
209212 self ._update_revision_txn (txn , user_id , room_id , next_id )
210213
211214 with self ._account_data_id_gen .get_next () as next_id :
212- yield self .db_pool .runInteraction ("add_tag" , add_tag_txn , next_id )
215+ await self .db_pool .runInteraction ("add_tag" , add_tag_txn , next_id )
213216
214217 self .get_tags_for_user .invalidate ((user_id ,))
215218
216- result = self ._account_data_id_gen .get_current_token ()
217- return result
219+ return self ._account_data_id_gen .get_current_token ()
218220
219- @defer .inlineCallbacks
220- def remove_tag_from_room (self , user_id , room_id , tag ):
221+ async def remove_tag_from_room (self , user_id : str , room_id : str , tag : str ) -> int :
221222 """Remove a tag from a room for a user.
223+
222224 Returns:
223- A deferred that completes once the tag has been removed
225+ The next account data ID.
224226 """
225227
226228 def remove_tag_txn (txn , next_id ):
@@ -232,21 +234,22 @@ def remove_tag_txn(txn, next_id):
232234 self ._update_revision_txn (txn , user_id , room_id , next_id )
233235
234236 with self ._account_data_id_gen .get_next () as next_id :
235- yield self .db_pool .runInteraction ("remove_tag" , remove_tag_txn , next_id )
237+ await self .db_pool .runInteraction ("remove_tag" , remove_tag_txn , next_id )
236238
237239 self .get_tags_for_user .invalidate ((user_id ,))
238240
239- result = self ._account_data_id_gen .get_current_token ()
240- return result
241+ return self ._account_data_id_gen .get_current_token ()
241242
242- def _update_revision_txn (self , txn , user_id , room_id , next_id ):
243+ def _update_revision_txn (
244+ self , txn , user_id : str , room_id : str , next_id : int
245+ ) -> None :
243246 """Update the latest revision of the tags for the given user and room.
244247
245248 Args:
246249 txn: The database cursor
247- user_id(str) : The ID of the user.
248- room_id(str) : The ID of the room.
249- next_id(int) : The the revision to advance to.
250+ user_id: The ID of the user.
251+ room_id: The ID of the room.
252+ next_id: The the revision to advance to.
250253 """
251254
252255 txn .call_after (
0 commit comments