33import yaml
44
55from synapse .config import ConfigError
6- from synapse .config .api import ApiConfig , StateKeyFilter
7-
8- DEFAULT_PREJOIN_STATE = {
9- "m.room.join_rules" : StateKeyFilter .only ("" ),
10- "m.room.canonical_alias" : StateKeyFilter .only ("" ),
11- "m.room.avatar" : StateKeyFilter .only ("" ),
12- "m.room.encryption" : StateKeyFilter .only ("" ),
13- "m.room.name" : StateKeyFilter .only ("" ),
14- "m.room.create" : StateKeyFilter .only ("" ),
15- "m.room.topic" : StateKeyFilter .only ("" ),
6+ from synapse .config .api import ApiConfig
7+ from synapse .types .state import StateFilter
8+
9+ DEFAULT_PREJOIN_STATE_PAIRS = {
10+ ("m.room.join_rules" , "" ),
11+ ("m.room.canonical_alias" , "" ),
12+ ("m.room.avatar" , "" ),
13+ ("m.room.encryption" , "" ),
14+ ("m.room.name" , "" ),
15+ ("m.room.create" , "" ),
16+ ("m.room.topic" , "" ),
1617}
1718
1819
1920class TestRoomPrejoinState (StdlibTestCase ):
20- def test_state_key_filter (self ) -> None :
21- """Sanity check the StateKeyFilter class."""
22- s = StateKeyFilter .only ("foo" )
23- self .assertIn ("foo" , s )
24- self .assertNotIn ("bar" , s )
25- self .assertNotIn ("baz" , s )
26- s .add ("bar" )
27- self .assertIn ("foo" , s )
28- self .assertIn ("bar" , s )
29- self .assertNotIn ("baz" , s )
30-
31- s = StateKeyFilter .any ()
32- self .assertIn ("foo" , s )
33- self .assertIn ("bar" , s )
34- self .assertIn ("baz" , s )
35- s .add ("bar" )
36- self .assertIn ("foo" , s )
37- self .assertIn ("bar" , s )
38- self .assertIn ("baz" , s )
39-
4021 def read_config (self , source : str ) -> ApiConfig :
4122 config = ApiConfig ()
4223 config .read_config (yaml .safe_load (source ))
4324 return config
4425
4526 def test_no_prejoin_state (self ) -> None :
4627 config = self .read_config ("foo: bar" )
47- self .assertEqual (config .room_prejoin_state , DEFAULT_PREJOIN_STATE )
28+ self .assertFalse (config .room_prejoin_state .has_wildcards ())
29+ self .assertEqual (
30+ set (config .room_prejoin_state .concrete_types ()), DEFAULT_PREJOIN_STATE_PAIRS
31+ )
4832
4933 def test_disable_default_event_types (self ) -> None :
5034 config = self .read_config (
@@ -53,7 +37,7 @@ def test_disable_default_event_types(self) -> None:
5337 disable_default_event_types: true
5438 """
5539 )
56- self .assertEqual (config .room_prejoin_state , {} )
40+ self .assertEqual (config .room_prejoin_state , StateFilter . none () )
5741
5842 def test_event_without_state_key (self ) -> None :
5943 config = self .read_config (
@@ -64,7 +48,8 @@ def test_event_without_state_key(self) -> None:
6448 - foo
6549 """
6650 )
67- self .assertEqual (config .room_prejoin_state , {"foo" : StateKeyFilter .any ()})
51+ self .assertEqual (config .room_prejoin_state .wildcard_types (), ["foo" ])
52+ self .assertEqual (config .room_prejoin_state .concrete_types (), [])
6853
6954 def test_event_with_specific_state_key (self ) -> None :
7055 config = self .read_config (
@@ -75,7 +60,11 @@ def test_event_with_specific_state_key(self) -> None:
7560 - [foo, bar]
7661 """
7762 )
78- self .assertEqual (config .room_prejoin_state , {"foo" : StateKeyFilter .only ("bar" )})
63+ self .assertFalse (config .room_prejoin_state .has_wildcards ())
64+ self .assertEqual (
65+ set (config .room_prejoin_state .concrete_types ()),
66+ {("foo" , "bar" )},
67+ )
7968
8069 def test_repeated_event_with_specific_state_key (self ) -> None :
8170 config = self .read_config (
@@ -87,8 +76,10 @@ def test_repeated_event_with_specific_state_key(self) -> None:
8776 - [foo, baz]
8877 """
8978 )
79+ self .assertFalse (config .room_prejoin_state .has_wildcards ())
9080 self .assertEqual (
91- config .room_prejoin_state , {"foo" : StateKeyFilter ({"bar" , "baz" })}
81+ set (config .room_prejoin_state .concrete_types ()),
82+ {("foo" , "bar" ), ("foo" , "baz" )},
9283 )
9384
9485 def test_no_specific_state_key_overrides_specific_state_key (self ) -> None :
@@ -101,7 +92,8 @@ def test_no_specific_state_key_overrides_specific_state_key(self) -> None:
10192 - foo
10293 """
10394 )
104- self .assertEqual (config .room_prejoin_state , {"foo" : StateKeyFilter .any ()})
95+ self .assertEqual (config .room_prejoin_state .wildcard_types (), ["foo" ])
96+ self .assertEqual (config .room_prejoin_state .concrete_types (), [])
10597
10698 config = self .read_config (
10799 """
@@ -112,7 +104,8 @@ def test_no_specific_state_key_overrides_specific_state_key(self) -> None:
112104 - [foo, bar]
113105 """
114106 )
115- self .assertEqual (config .room_prejoin_state , {"foo" : StateKeyFilter .any ()})
107+ self .assertEqual (config .room_prejoin_state .wildcard_types (), ["foo" ])
108+ self .assertEqual (config .room_prejoin_state .concrete_types (), [])
116109
117110 def test_bad_event_type_entry_raises (self ) -> None :
118111 with self .assertRaises (ConfigError ):
0 commit comments