14
14
from twisted .internet import defer
15
15
16
16
from synapse .logging .context import make_deferred_yieldable
17
- from synapse .util .batching_queue import BatchingQueue
17
+ from synapse .util .batching_queue import (
18
+ BatchingQueue ,
19
+ number_in_flight ,
20
+ number_of_keys ,
21
+ number_queued ,
22
+ )
18
23
19
24
from tests .server import get_clock
20
25
from tests .unittest import TestCase
@@ -24,6 +29,14 @@ class BatchingQueueTestCase(TestCase):
24
29
def setUp (self ):
25
30
self .clock , hs_clock = get_clock ()
26
31
32
+ # We ensure that we remove any existing metrics for "test_queue".
33
+ try :
34
+ number_queued .remove ("test_queue" )
35
+ number_of_keys .remove ("test_queue" )
36
+ number_in_flight .remove ("test_queue" )
37
+ except KeyError :
38
+ pass
39
+
27
40
self ._pending_calls = []
28
41
self .queue = BatchingQueue ("test_queue" , hs_clock , self ._process_queue )
29
42
@@ -32,6 +45,41 @@ async def _process_queue(self, values):
32
45
self ._pending_calls .append ((values , d ))
33
46
return await make_deferred_yieldable (d )
34
47
48
+ def _assert_metrics (self , queued , keys , in_flight ):
49
+ """Assert that the metrics are correct"""
50
+
51
+ self .assertEqual (len (number_queued .collect ()), 1 )
52
+ self .assertEqual (len (number_queued .collect ()[0 ].samples ), 1 )
53
+ self .assertEqual (
54
+ number_queued .collect ()[0 ].samples [0 ].labels ,
55
+ {"name" : self .queue ._name },
56
+ )
57
+ self .assertEqual (
58
+ number_queued .collect ()[0 ].samples [0 ].value ,
59
+ queued ,
60
+ "number_queued" ,
61
+ )
62
+
63
+ self .assertEqual (len (number_of_keys .collect ()), 1 )
64
+ self .assertEqual (len (number_of_keys .collect ()[0 ].samples ), 1 )
65
+ self .assertEqual (
66
+ number_queued .collect ()[0 ].samples [0 ].labels , {"name" : self .queue ._name }
67
+ )
68
+ self .assertEqual (
69
+ number_of_keys .collect ()[0 ].samples [0 ].value , keys , "number_of_keys"
70
+ )
71
+
72
+ self .assertEqual (len (number_in_flight .collect ()), 1 )
73
+ self .assertEqual (len (number_in_flight .collect ()[0 ].samples ), 1 )
74
+ self .assertEqual (
75
+ number_queued .collect ()[0 ].samples [0 ].labels , {"name" : self .queue ._name }
76
+ )
77
+ self .assertEqual (
78
+ number_in_flight .collect ()[0 ].samples [0 ].value ,
79
+ in_flight ,
80
+ "number_in_flight" ,
81
+ )
82
+
35
83
def test_simple (self ):
36
84
"""Tests the basic case of calling `add_to_queue` once and having
37
85
`_process_queue` return.
@@ -41,6 +89,8 @@ def test_simple(self):
41
89
42
90
queue_d = defer .ensureDeferred (self .queue .add_to_queue ("foo" ))
43
91
92
+ self ._assert_metrics (queued = 1 , keys = 1 , in_flight = 1 )
93
+
44
94
# The queue should wait a reactor tick before calling the processing
45
95
# function.
46
96
self .assertFalse (self ._pending_calls )
@@ -52,12 +102,15 @@ def test_simple(self):
52
102
self .assertEqual (len (self ._pending_calls ), 1 )
53
103
self .assertEqual (self ._pending_calls [0 ][0 ], ["foo" ])
54
104
self .assertFalse (queue_d .called )
105
+ self ._assert_metrics (queued = 0 , keys = 0 , in_flight = 1 )
55
106
56
107
# Return value of the `_process_queue` should be propagated back.
57
108
self ._pending_calls .pop ()[1 ].callback ("bar" )
58
109
59
110
self .assertEqual (self .successResultOf (queue_d ), "bar" )
60
111
112
+ self ._assert_metrics (queued = 0 , keys = 0 , in_flight = 0 )
113
+
61
114
def test_batching (self ):
62
115
"""Test that multiple calls at the same time get batched up into one
63
116
call to `_process_queue`.
@@ -68,19 +121,23 @@ def test_batching(self):
68
121
queue_d1 = defer .ensureDeferred (self .queue .add_to_queue ("foo1" ))
69
122
queue_d2 = defer .ensureDeferred (self .queue .add_to_queue ("foo2" ))
70
123
124
+ self ._assert_metrics (queued = 2 , keys = 1 , in_flight = 2 )
125
+
71
126
self .clock .pump ([0 ])
72
127
73
128
# We should see only *one* call to `_process_queue`
74
129
self .assertEqual (len (self ._pending_calls ), 1 )
75
130
self .assertEqual (self ._pending_calls [0 ][0 ], ["foo1" , "foo2" ])
76
131
self .assertFalse (queue_d1 .called )
77
132
self .assertFalse (queue_d2 .called )
133
+ self ._assert_metrics (queued = 0 , keys = 0 , in_flight = 2 )
78
134
79
135
# Return value of the `_process_queue` should be propagated back to both.
80
136
self ._pending_calls .pop ()[1 ].callback ("bar" )
81
137
82
138
self .assertEqual (self .successResultOf (queue_d1 ), "bar" )
83
139
self .assertEqual (self .successResultOf (queue_d2 ), "bar" )
140
+ self ._assert_metrics (queued = 0 , keys = 0 , in_flight = 0 )
84
141
85
142
def test_queuing (self ):
86
143
"""Test that we queue up requests while a `_process_queue` is being
@@ -92,32 +149,45 @@ def test_queuing(self):
92
149
queue_d1 = defer .ensureDeferred (self .queue .add_to_queue ("foo1" ))
93
150
self .clock .pump ([0 ])
94
151
152
+ self .assertEqual (len (self ._pending_calls ), 1 )
153
+
154
+ # We queue up work after the process function has been called, testing
155
+ # that they get correctly queued up.
95
156
queue_d2 = defer .ensureDeferred (self .queue .add_to_queue ("foo2" ))
157
+ queue_d3 = defer .ensureDeferred (self .queue .add_to_queue ("foo3" ))
96
158
97
159
# We should see only *one* call to `_process_queue`
98
160
self .assertEqual (len (self ._pending_calls ), 1 )
99
161
self .assertEqual (self ._pending_calls [0 ][0 ], ["foo1" ])
100
162
self .assertFalse (queue_d1 .called )
101
163
self .assertFalse (queue_d2 .called )
164
+ self .assertFalse (queue_d3 .called )
165
+ self ._assert_metrics (queued = 2 , keys = 1 , in_flight = 3 )
102
166
103
167
# Return value of the `_process_queue` should be propagated back to the
104
168
# first.
105
169
self ._pending_calls .pop ()[1 ].callback ("bar1" )
106
170
107
171
self .assertEqual (self .successResultOf (queue_d1 ), "bar1" )
108
172
self .assertFalse (queue_d2 .called )
173
+ self .assertFalse (queue_d3 .called )
174
+ self ._assert_metrics (queued = 2 , keys = 1 , in_flight = 2 )
109
175
110
176
# We should now see a second call to `_process_queue`
111
177
self .clock .pump ([0 ])
112
178
self .assertEqual (len (self ._pending_calls ), 1 )
113
- self .assertEqual (self ._pending_calls [0 ][0 ], ["foo2" ])
179
+ self .assertEqual (self ._pending_calls [0 ][0 ], ["foo2" , "foo3" ])
114
180
self .assertFalse (queue_d2 .called )
181
+ self .assertFalse (queue_d3 .called )
182
+ self ._assert_metrics (queued = 0 , keys = 0 , in_flight = 2 )
115
183
116
184
# Return value of the `_process_queue` should be propagated back to the
117
185
# second.
118
186
self ._pending_calls .pop ()[1 ].callback ("bar2" )
119
187
120
188
self .assertEqual (self .successResultOf (queue_d2 ), "bar2" )
189
+ self .assertEqual (self .successResultOf (queue_d3 ), "bar2" )
190
+ self ._assert_metrics (queued = 0 , keys = 0 , in_flight = 0 )
121
191
122
192
def test_different_keys (self ):
123
193
"""Test that calls to different keys get processed in parallel."""
@@ -140,6 +210,7 @@ def test_different_keys(self):
140
210
self .assertFalse (queue_d1 .called )
141
211
self .assertFalse (queue_d2 .called )
142
212
self .assertFalse (queue_d3 .called )
213
+ self ._assert_metrics (queued = 1 , keys = 1 , in_flight = 3 )
143
214
144
215
# Return value of the `_process_queue` should be propagated back to the
145
216
# first.
@@ -148,6 +219,7 @@ def test_different_keys(self):
148
219
self .assertEqual (self .successResultOf (queue_d1 ), "bar1" )
149
220
self .assertFalse (queue_d2 .called )
150
221
self .assertFalse (queue_d3 .called )
222
+ self ._assert_metrics (queued = 1 , keys = 1 , in_flight = 2 )
151
223
152
224
# Return value of the `_process_queue` should be propagated back to the
153
225
# second.
@@ -161,9 +233,11 @@ def test_different_keys(self):
161
233
self .assertEqual (len (self ._pending_calls ), 1 )
162
234
self .assertEqual (self ._pending_calls [0 ][0 ], ["foo3" ])
163
235
self .assertFalse (queue_d3 .called )
236
+ self ._assert_metrics (queued = 0 , keys = 0 , in_flight = 1 )
164
237
165
238
# Return value of the `_process_queue` should be propagated back to the
166
239
# third deferred.
167
240
self ._pending_calls .pop ()[1 ].callback ("bar4" )
168
241
169
242
self .assertEqual (self .successResultOf (queue_d3 ), "bar4" )
243
+ self ._assert_metrics (queued = 0 , keys = 0 , in_flight = 0 )
0 commit comments