diff --git a/src/neo4j/_async/driver.py b/src/neo4j/_async/driver.py index b60440ef4..1045d7da3 100644 --- a/src/neo4j/_async/driver.py +++ b/src/neo4j/_async/driver.py @@ -18,6 +18,7 @@ from __future__ import annotations +import asyncio import typing as t @@ -466,7 +467,11 @@ def session(self, **config) -> AsyncSession: async def close(self) -> None: """ Shut down, closing any open connections in the pool. """ - await self._pool.close() + try: + await self._pool.close() + except asyncio.CancelledError: + self._closed = True + raise self._closed = True if t.TYPE_CHECKING: diff --git a/src/neo4j/_sync/driver.py b/src/neo4j/_sync/driver.py index 5aace35b7..7ad02ef7d 100644 --- a/src/neo4j/_sync/driver.py +++ b/src/neo4j/_sync/driver.py @@ -18,6 +18,7 @@ from __future__ import annotations +import asyncio import typing as t @@ -463,7 +464,11 @@ def session(self, **config) -> Session: def close(self) -> None: """ Shut down, closing any open connections in the pool. """ - self._pool.close() + try: + self._pool.close() + except asyncio.CancelledError: + self._closed = True + raise self._closed = True if t.TYPE_CHECKING: diff --git a/tests/integration/mixed/test_async_driver.py b/tests/integration/mixed/test_async_driver.py index 8e2571cbc..5f71700a6 100644 --- a/tests/integration/mixed/test_async_driver.py +++ b/tests/integration/mixed/test_async_driver.py @@ -24,6 +24,7 @@ import neo4j from ... import env +from ..._async_compat import mark_async_test # TODO: Python 3.9: when support gets dropped, remove this mark @@ -44,7 +45,7 @@ def test_can_create_async_driver_outside_of_loop(uri, auth): async def return_1(tx: neo4j.AsyncManagedTransaction) -> None: nonlocal counter, was_full - res = await tx.run("RETURN 1") + res = await tx.run("UNWIND range(1, 10000) AS x RETURN x") counter += 1 while not was_full and counter < pool_size: @@ -86,3 +87,55 @@ async def run(driver_: neo4j.AsyncDriver): loop.run_until_complete(coro) finally: loop.close() + + +@mark_async_test +async def test_cancel_driver_close(uri, auth): + class Signal: + queried = False + released = False + + async def fill_pool(driver_: neo4j.AsyncDriver, n=10): + signals = [Signal() for _ in range(n)] + await asyncio.gather( + *(handle_session(driver_.session(), signals[i]) for i in range(n)), + handle_signals(signals), + return_exceptions=True, + ) + + async def handle_signals(signals): + while any(not signal.queried for signal in signals): + await asyncio.sleep(0.001) + await asyncio.sleep(0.1) + for signal in signals: + signal.released = True + + async def handle_session(session, signal): + async with session: + await session.execute_read(work, signal) + + async def work(tx: neo4j.AsyncManagedTransaction, signal: Signal) -> None: + res = await tx.run("UNWIND range(1, 10000) AS x RETURN x") + signal.queried = True + while not signal.released: + await asyncio.sleep(0.001) + await res.consume() + + def connection_count(driver_): + return sum(len(v) for v in driver_._pool.connections.values()) + + driver = neo4j.AsyncGraphDatabase.driver(uri, auth=auth) + await fill_pool(driver) + # sanity check, there should be some connections + assert connection_count(driver) >= 10 + + # start the close and give it some event loop iterations to kick off + fut = asyncio.ensure_future(driver.close()) + await asyncio.sleep(0) + + # cancel in the middle of closing connections + fut.cancel() + # give the driver a chance to close connections forcefully + await asyncio.sleep(0) + # driver should be marked as closed to not emmit a ResourceWarning later + assert driver._closed == True