Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 78b5102

Browse files
authored
Fix up BatchingQueue (#10078)
Fixes #10068
1 parent d9f44fd commit 78b5102

File tree

3 files changed

+125
-24
lines changed

3 files changed

+125
-24
lines changed

changelog.d/10078.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix up `BatchingQueue` implementation.

synapse/util/batching_queue.py

Lines changed: 48 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,11 @@
2525
TypeVar,
2626
)
2727

28+
from prometheus_client import Gauge
29+
2830
from twisted.internet import defer
2931

3032
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
31-
from synapse.metrics import LaterGauge
3233
from synapse.metrics.background_process_metrics import run_as_background_process
3334
from synapse.util import Clock
3435

@@ -38,6 +39,24 @@
3839
V = TypeVar("V")
3940
R = TypeVar("R")
4041

42+
number_queued = Gauge(
43+
"synapse_util_batching_queue_number_queued",
44+
"The number of items waiting in the queue across all keys",
45+
labelnames=("name",),
46+
)
47+
48+
number_in_flight = Gauge(
49+
"synapse_util_batching_queue_number_pending",
50+
"The number of items across all keys either being processed or waiting in a queue",
51+
labelnames=("name",),
52+
)
53+
54+
number_of_keys = Gauge(
55+
"synapse_util_batching_queue_number_of_keys",
56+
"The number of distinct keys that have items queued",
57+
labelnames=("name",),
58+
)
59+
4160

4261
class BatchingQueue(Generic[V, R]):
4362
"""A queue that batches up work, calling the provided processing function
@@ -48,10 +67,20 @@ class BatchingQueue(Generic[V, R]):
4867
called, and will keep being called until the queue has been drained (for the
4968
given key).
5069
70+
If the processing function raises an exception then the exception is proxied
71+
through to the callers waiting on that batch of work.
72+
5173
Note that the return value of `add_to_queue` will be the return value of the
5274
processing function that processed the given item. This means that the
5375
returned value will likely include data for other items that were in the
5476
batch.
77+
78+
Args:
79+
name: A name for the queue, used for logging contexts and metrics.
80+
This must be unique, otherwise the metrics will be wrong.
81+
clock: The clock to use to schedule work.
82+
process_batch_callback: The callback to be run to process a batch of
83+
work.
5584
"""
5685

5786
def __init__(
@@ -73,19 +102,15 @@ def __init__(
73102
# The function to call with batches of values.
74103
self._process_batch_callback = process_batch_callback
75104

76-
LaterGauge(
77-
"synapse_util_batching_queue_number_queued",
78-
"The number of items waiting in the queue across all keys",
79-
labels=("name",),
80-
caller=lambda: sum(len(v) for v in self._next_values.values()),
105+
number_queued.labels(self._name).set_function(
106+
lambda: sum(len(q) for q in self._next_values.values())
81107
)
82108

83-
LaterGauge(
84-
"synapse_util_batching_queue_number_of_keys",
85-
"The number of distinct keys that have items queued",
86-
labels=("name",),
87-
caller=lambda: len(self._next_values),
88-
)
109+
number_of_keys.labels(self._name).set_function(lambda: len(self._next_values))
110+
111+
self._number_in_flight_metric = number_in_flight.labels(
112+
self._name
113+
) # type: Gauge
89114

90115
async def add_to_queue(self, value: V, key: Hashable = ()) -> R:
91116
"""Adds the value to the queue with the given key, returning the result
@@ -107,17 +132,18 @@ async def add_to_queue(self, value: V, key: Hashable = ()) -> R:
107132
if key not in self._processing_keys:
108133
run_as_background_process(self._name, self._process_queue, key)
109134

110-
return await make_deferred_yieldable(d)
135+
with self._number_in_flight_metric.track_inprogress():
136+
return await make_deferred_yieldable(d)
111137

112138
async def _process_queue(self, key: Hashable) -> None:
113139
"""A background task to repeatedly pull things off the queue for the
114140
given key and call the `self._process_batch_callback` with the values.
115141
"""
116142

117-
try:
118-
if key in self._processing_keys:
119-
return
143+
if key in self._processing_keys:
144+
return
120145

146+
try:
121147
self._processing_keys.add(key)
122148

123149
while True:
@@ -137,16 +163,16 @@ async def _process_queue(self, key: Hashable) -> None:
137163
values = [value for value, _ in next_values]
138164
results = await self._process_batch_callback(values)
139165

140-
for _, deferred in next_values:
141-
with PreserveLoggingContext():
166+
with PreserveLoggingContext():
167+
for _, deferred in next_values:
142168
deferred.callback(results)
143169

144170
except Exception as e:
145-
for _, deferred in next_values:
146-
if deferred.called:
147-
continue
171+
with PreserveLoggingContext():
172+
for _, deferred in next_values:
173+
if deferred.called:
174+
continue
148175

149-
with PreserveLoggingContext():
150176
deferred.errback(e)
151177

152178
finally:

tests/util/test_batching_queue.py

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,12 @@
1414
from twisted.internet import defer
1515

1616
from synapse.logging.context import make_deferred_yieldable
17-
from synapse.util.batching_queue import BatchingQueue
17+
from synapse.util.batching_queue import (
18+
BatchingQueue,
19+
number_in_flight,
20+
number_of_keys,
21+
number_queued,
22+
)
1823

1924
from tests.server import get_clock
2025
from tests.unittest import TestCase
@@ -24,6 +29,14 @@ class BatchingQueueTestCase(TestCase):
2429
def setUp(self):
2530
self.clock, hs_clock = get_clock()
2631

32+
# We ensure that we remove any existing metrics for "test_queue".
33+
try:
34+
number_queued.remove("test_queue")
35+
number_of_keys.remove("test_queue")
36+
number_in_flight.remove("test_queue")
37+
except KeyError:
38+
pass
39+
2740
self._pending_calls = []
2841
self.queue = BatchingQueue("test_queue", hs_clock, self._process_queue)
2942

@@ -32,6 +45,41 @@ async def _process_queue(self, values):
3245
self._pending_calls.append((values, d))
3346
return await make_deferred_yieldable(d)
3447

48+
def _assert_metrics(self, queued, keys, in_flight):
49+
"""Assert that the metrics are correct"""
50+
51+
self.assertEqual(len(number_queued.collect()), 1)
52+
self.assertEqual(len(number_queued.collect()[0].samples), 1)
53+
self.assertEqual(
54+
number_queued.collect()[0].samples[0].labels,
55+
{"name": self.queue._name},
56+
)
57+
self.assertEqual(
58+
number_queued.collect()[0].samples[0].value,
59+
queued,
60+
"number_queued",
61+
)
62+
63+
self.assertEqual(len(number_of_keys.collect()), 1)
64+
self.assertEqual(len(number_of_keys.collect()[0].samples), 1)
65+
self.assertEqual(
66+
number_queued.collect()[0].samples[0].labels, {"name": self.queue._name}
67+
)
68+
self.assertEqual(
69+
number_of_keys.collect()[0].samples[0].value, keys, "number_of_keys"
70+
)
71+
72+
self.assertEqual(len(number_in_flight.collect()), 1)
73+
self.assertEqual(len(number_in_flight.collect()[0].samples), 1)
74+
self.assertEqual(
75+
number_queued.collect()[0].samples[0].labels, {"name": self.queue._name}
76+
)
77+
self.assertEqual(
78+
number_in_flight.collect()[0].samples[0].value,
79+
in_flight,
80+
"number_in_flight",
81+
)
82+
3583
def test_simple(self):
3684
"""Tests the basic case of calling `add_to_queue` once and having
3785
`_process_queue` return.
@@ -41,6 +89,8 @@ def test_simple(self):
4189

4290
queue_d = defer.ensureDeferred(self.queue.add_to_queue("foo"))
4391

92+
self._assert_metrics(queued=1, keys=1, in_flight=1)
93+
4494
# The queue should wait a reactor tick before calling the processing
4595
# function.
4696
self.assertFalse(self._pending_calls)
@@ -52,12 +102,15 @@ def test_simple(self):
52102
self.assertEqual(len(self._pending_calls), 1)
53103
self.assertEqual(self._pending_calls[0][0], ["foo"])
54104
self.assertFalse(queue_d.called)
105+
self._assert_metrics(queued=0, keys=0, in_flight=1)
55106

56107
# Return value of the `_process_queue` should be propagated back.
57108
self._pending_calls.pop()[1].callback("bar")
58109

59110
self.assertEqual(self.successResultOf(queue_d), "bar")
60111

112+
self._assert_metrics(queued=0, keys=0, in_flight=0)
113+
61114
def test_batching(self):
62115
"""Test that multiple calls at the same time get batched up into one
63116
call to `_process_queue`.
@@ -68,19 +121,23 @@ def test_batching(self):
68121
queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1"))
69122
queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2"))
70123

124+
self._assert_metrics(queued=2, keys=1, in_flight=2)
125+
71126
self.clock.pump([0])
72127

73128
# We should see only *one* call to `_process_queue`
74129
self.assertEqual(len(self._pending_calls), 1)
75130
self.assertEqual(self._pending_calls[0][0], ["foo1", "foo2"])
76131
self.assertFalse(queue_d1.called)
77132
self.assertFalse(queue_d2.called)
133+
self._assert_metrics(queued=0, keys=0, in_flight=2)
78134

79135
# Return value of the `_process_queue` should be propagated back to both.
80136
self._pending_calls.pop()[1].callback("bar")
81137

82138
self.assertEqual(self.successResultOf(queue_d1), "bar")
83139
self.assertEqual(self.successResultOf(queue_d2), "bar")
140+
self._assert_metrics(queued=0, keys=0, in_flight=0)
84141

85142
def test_queuing(self):
86143
"""Test that we queue up requests while a `_process_queue` is being
@@ -92,32 +149,45 @@ def test_queuing(self):
92149
queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1"))
93150
self.clock.pump([0])
94151

152+
self.assertEqual(len(self._pending_calls), 1)
153+
154+
# We queue up work after the process function has been called, testing
155+
# that they get correctly queued up.
95156
queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2"))
157+
queue_d3 = defer.ensureDeferred(self.queue.add_to_queue("foo3"))
96158

97159
# We should see only *one* call to `_process_queue`
98160
self.assertEqual(len(self._pending_calls), 1)
99161
self.assertEqual(self._pending_calls[0][0], ["foo1"])
100162
self.assertFalse(queue_d1.called)
101163
self.assertFalse(queue_d2.called)
164+
self.assertFalse(queue_d3.called)
165+
self._assert_metrics(queued=2, keys=1, in_flight=3)
102166

103167
# Return value of the `_process_queue` should be propagated back to the
104168
# first.
105169
self._pending_calls.pop()[1].callback("bar1")
106170

107171
self.assertEqual(self.successResultOf(queue_d1), "bar1")
108172
self.assertFalse(queue_d2.called)
173+
self.assertFalse(queue_d3.called)
174+
self._assert_metrics(queued=2, keys=1, in_flight=2)
109175

110176
# We should now see a second call to `_process_queue`
111177
self.clock.pump([0])
112178
self.assertEqual(len(self._pending_calls), 1)
113-
self.assertEqual(self._pending_calls[0][0], ["foo2"])
179+
self.assertEqual(self._pending_calls[0][0], ["foo2", "foo3"])
114180
self.assertFalse(queue_d2.called)
181+
self.assertFalse(queue_d3.called)
182+
self._assert_metrics(queued=0, keys=0, in_flight=2)
115183

116184
# Return value of the `_process_queue` should be propagated back to the
117185
# second.
118186
self._pending_calls.pop()[1].callback("bar2")
119187

120188
self.assertEqual(self.successResultOf(queue_d2), "bar2")
189+
self.assertEqual(self.successResultOf(queue_d3), "bar2")
190+
self._assert_metrics(queued=0, keys=0, in_flight=0)
121191

122192
def test_different_keys(self):
123193
"""Test that calls to different keys get processed in parallel."""
@@ -140,6 +210,7 @@ def test_different_keys(self):
140210
self.assertFalse(queue_d1.called)
141211
self.assertFalse(queue_d2.called)
142212
self.assertFalse(queue_d3.called)
213+
self._assert_metrics(queued=1, keys=1, in_flight=3)
143214

144215
# Return value of the `_process_queue` should be propagated back to the
145216
# first.
@@ -148,6 +219,7 @@ def test_different_keys(self):
148219
self.assertEqual(self.successResultOf(queue_d1), "bar1")
149220
self.assertFalse(queue_d2.called)
150221
self.assertFalse(queue_d3.called)
222+
self._assert_metrics(queued=1, keys=1, in_flight=2)
151223

152224
# Return value of the `_process_queue` should be propagated back to the
153225
# second.
@@ -161,9 +233,11 @@ def test_different_keys(self):
161233
self.assertEqual(len(self._pending_calls), 1)
162234
self.assertEqual(self._pending_calls[0][0], ["foo3"])
163235
self.assertFalse(queue_d3.called)
236+
self._assert_metrics(queued=0, keys=0, in_flight=1)
164237

165238
# Return value of the `_process_queue` should be propagated back to the
166239
# third deferred.
167240
self._pending_calls.pop()[1].callback("bar4")
168241

169242
self.assertEqual(self.successResultOf(queue_d3), "bar4")
243+
self._assert_metrics(queued=0, keys=0, in_flight=0)

0 commit comments

Comments
 (0)