Skip to content

Commit 23a74cf

Browse files
Added Batch signal subscriber. (#9)
* Added Batch signal subscriber. * Import sorting
1 parent 6b9e207 commit 23a74cf

File tree

4 files changed

+115
-6
lines changed

4 files changed

+115
-6
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "asyncio-signal-bus"
3-
version = "1.2.1"
3+
version = "1.3.0"
44
description = "Internal application publisher/subscriber bus using asyncio queues."
55
authors = ["DustinMoriarty <dustin.moriarty@protonmail.com>"]
66
readme = "README.md"

src/asyncio_signal_bus/signal_bus.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from asyncio_signal_bus.injector import Injector
88
from asyncio_signal_bus.publisher import SignalPublisher
99
from asyncio_signal_bus.queue_getter import QueueGetter
10-
from asyncio_signal_bus.subscriber import SignalSubscriber
10+
from asyncio_signal_bus.subscriber import BatchSignalSubscriber, SignalSubscriber
1111
from asyncio_signal_bus.types import R, S
1212

1313
LOGGER = getLogger(__name__)
@@ -89,7 +89,52 @@ def subscriber(
8989
self._queues.get(topic_name).append(queue)
9090

9191
def _wrapper(f: Callable[[S], Awaitable[R]]) -> SignalSubscriber[S, R]:
92-
s = SignalSubscriber(error_handler(f), queue)
92+
s = SignalSubscriber(
93+
error_handler(f), queue, shutdown_timeout=shutdown_timeout
94+
)
95+
LOGGER.debug(f"Registering subscriber to topic {topic_name}")
96+
self._subscribers.append(s)
97+
return s
98+
99+
return _wrapper
100+
101+
def batch_subscriber(
102+
self,
103+
topic_name="default",
104+
error_handler: Type[SubscriberErrorHandler] = SubscriberErrorHandler[S, R],
105+
shutdown_timeout: Optional[SupportsFloat] = 120,
106+
max_items: int = 10,
107+
period_seconds: int = 10,
108+
):
109+
"""
110+
A subscriber that consumes batches of events. The subscriber will wait no longer
111+
than the period_seconds between aggregations. Batches will not exceed batch
112+
seconds in size.
113+
:param topic_name: The name of the topic used to link one or more subscribers
114+
with one or more publishers.
115+
:param error_handler: An error handler used to handle errors within the callable.
116+
Error handling should usually terminate at the subscriber, with the
117+
subscriber catching all exceptions. Any unhandled errors will block the
118+
shutdown of the bus when the bus exits context or the stop method is used.
119+
:param shutdown_timeout: If the subscriber takes longer than this time during
120+
shutdown, then the task is killed and an error is raised. If you do not
121+
want the task timeout to be limited, then set this value to None.
122+
:param max_items: The maximum amount of time for the batch.
123+
:param period_seconds: The maximum amount of timem to wait between batches.
124+
:return: Wrapped callable
125+
"""
126+
self._queues.setdefault(topic_name, [])
127+
queue = Queue()
128+
self._queues.get(topic_name).append(queue)
129+
130+
def _wrapper(f: Callable[[S], Awaitable[R]]) -> SignalSubscriber[S, R]:
131+
s = BatchSignalSubscriber(
132+
error_handler(f),
133+
queue,
134+
shutdown_timeout=shutdown_timeout,
135+
max_items=max_items,
136+
period_seconds=period_seconds,
137+
)
93138
LOGGER.debug(f"Registering subscriber to topic {topic_name}")
94139
self._subscribers.append(s)
95140
return s

src/asyncio_signal_bus/subscriber.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import asyncio
22
from asyncio import Queue, Task
33
from logging import getLogger
4-
from typing import Awaitable, Callable, Generic, Optional, SupportsFloat
4+
from time import time
5+
from typing import Awaitable, Callable, Generic, Optional, SupportsFloat, SupportsInt
56

67
from asyncio_signal_bus.exception import SignalBusShutdownError
78
from asyncio_signal_bus.types import R, S
@@ -71,3 +72,42 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
7172

7273
async def __call__(self, signal: S, *args, **kwargs) -> R:
7374
return await self._f(signal, *args, **kwargs)
75+
76+
77+
class BatchSignalSubscriber(SignalSubscriber):
78+
def __init__(
79+
self,
80+
f: Callable[[S], Awaitable[R]],
81+
queue: Queue,
82+
shutdown_timeout: SupportsFloat = 120,
83+
max_items: SupportsInt = 10,
84+
period_seconds: SupportsFloat = 10,
85+
):
86+
super().__init__(f, queue, shutdown_timeout)
87+
self.max_items = int(max_items)
88+
self.period_seconds = float(period_seconds)
89+
90+
async def _batch_task_wrapper(self, coroutine, n_items: int):
91+
await coroutine
92+
for i in range(n_items):
93+
self._queue.task_done()
94+
95+
async def _listen(self):
96+
LOGGER.debug("Started listening.")
97+
batch = []
98+
ts = time()
99+
while True:
100+
if len(batch) < self.max_items and not self._queue.empty():
101+
signal = self._queue.get_nowait()
102+
batch.append(signal)
103+
if batch and (
104+
len(batch) >= self.max_items or (time() - ts) > self.period_seconds
105+
):
106+
asyncio.create_task(self._batch_task_wrapper(self(batch), len(batch)))
107+
ts = time()
108+
batch = []
109+
# This deadlocks unless there is some sleep. It does not need to be long.
110+
# My guess is that under the hood we just need to allow at least one clock
111+
# cycle for something else to happen so that the event loop can perform
112+
# other operations.
113+
await asyncio.sleep(1e-10)

tests/test_subscriber.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import asyncio
22
from asyncio import Queue
3-
from unittest.mock import Mock
3+
from unittest.mock import Mock, call
44

55
import pytest
66

77
from asyncio_signal_bus.exception import SignalBusShutdownError
8-
from asyncio_signal_bus.subscriber import SignalSubscriber
8+
from asyncio_signal_bus.subscriber import BatchSignalSubscriber, SignalSubscriber
99

1010

1111
@pytest.mark.asyncio
@@ -36,3 +36,27 @@ async def foo_subscriber(signal: str):
3636
with pytest.raises(SignalBusShutdownError):
3737
async with signal_subscriber:
3838
await subscriber_queue.put("a")
39+
40+
41+
@pytest.mark.asyncio
42+
async def test_batch_subscriber():
43+
subscriber_queue = Queue()
44+
45+
target_mock = Mock()
46+
47+
async def foo_subscriber(signal: str):
48+
target_mock(signal)
49+
50+
signal_subscriber = BatchSignalSubscriber(
51+
foo_subscriber,
52+
subscriber_queue,
53+
max_items=3,
54+
period_seconds=0.1,
55+
shutdown_timeout=0.3,
56+
)
57+
async with signal_subscriber:
58+
for i in range(5):
59+
await subscriber_queue.put(i)
60+
await asyncio.sleep(0.2)
61+
await subscriber_queue.put(6)
62+
target_mock.assert_has_calls([call([0, 1, 2]), call([3, 4]), call([6])])

0 commit comments

Comments
 (0)