diff --git a/asyncio_pool/base_pool.py b/asyncio_pool/base_pool.py index c96bcff..7056f21 100644 --- a/asyncio_pool/base_pool.py +++ b/asyncio_pool/base_pool.py @@ -238,6 +238,9 @@ async def map(self, fn, iterable, cb=None, ctx=None, *, fut = await self.spawn(fn(it), cb, ctx) futures.append(fut) + if not futures: + return [] + await aio.wait(futures) return [get_result(fut) for fut in futures] diff --git a/tests/test_map.py b/tests/test_map.py index 061f3ff..113fb37 100644 --- a/tests/test_map.py +++ b/tests/test_map.py @@ -16,6 +16,14 @@ async def test_map_simple(): assert res == [i*10 for i in task] +@pytest.mark.asyncio +async def test_map_empty(): + task = [] + pool = AioPool(size=7) + res = await pool.map(wrk, task) + assert res == [] + + @pytest.mark.asyncio async def test_map_crash(): task = range(5) @@ -52,6 +60,17 @@ async def wrk(n): i += 1 # does not support enumerate btw ( +@pytest.mark.asyncio +async def test_itermap_empty(): + async def wrk(n): + await aio.sleep(n) + return n + + async with AioPool(size=3) as pool: + async for res in pool.itermap(wrk, [], flat=False, timeout=0.6): + assert False + + @pytest.mark.asyncio async def test_itermap_cancel():