Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/neo4j/_async/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from __future__ import annotations

import asyncio
import typing as t


Expand Down Expand Up @@ -466,7 +467,11 @@ def session(self, **config) -> AsyncSession:
async def close(self) -> None:
""" Shut down, closing any open connections in the pool.
"""
await self._pool.close()
try:
await self._pool.close()
except asyncio.CancelledError:
self._closed = True
raise
self._closed = True

if t.TYPE_CHECKING:
Expand Down
7 changes: 6 additions & 1 deletion src/neo4j/_sync/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from __future__ import annotations

import asyncio
import typing as t


Expand Down Expand Up @@ -463,7 +464,11 @@ def session(self, **config) -> Session:
def close(self) -> None:
""" Shut down, closing any open connections in the pool.
"""
self._pool.close()
try:
self._pool.close()
except asyncio.CancelledError:
self._closed = True
raise
self._closed = True

if t.TYPE_CHECKING:
Expand Down
55 changes: 54 additions & 1 deletion tests/integration/mixed/test_async_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import neo4j

from ... import env
from ..._async_compat import mark_async_test


# TODO: Python 3.9: when support gets dropped, remove this mark
Expand All @@ -44,7 +45,7 @@ def test_can_create_async_driver_outside_of_loop(uri, auth):

async def return_1(tx):
nonlocal counter, was_full
res = await tx.run("RETURN 1")
res = await tx.run("UNWIND range(1, 10000) AS x RETURN x")

counter += 1
while not was_full and counter < pool_size:
Expand Down Expand Up @@ -92,3 +93,55 @@ async def run(driver: neo4j.AsyncDriver):
loop.run_until_complete(coro)
finally:
loop.close()


@mark_async_test
async def test_cancel_driver_close(uri, auth):
class Signal:
queried = False
released = False

async def fill_pool(driver_: neo4j.AsyncDriver, n=10):
signals = [Signal() for _ in range(n)]
await asyncio.gather(
*(handle_session(driver_.session(), signals[i]) for i in range(n)),
handle_signals(signals),
return_exceptions=True,
)

async def handle_signals(signals):
while any(not signal.queried for signal in signals):
await asyncio.sleep(0.001)
await asyncio.sleep(0.1)
for signal in signals:
signal.released = True

async def handle_session(session, signal):
async with session:
await session.execute_read(work, signal)

async def work(tx: neo4j.AsyncManagedTransaction, signal: Signal) -> None:
res = await tx.run("UNWIND range(1, 10000) AS x RETURN x")
signal.queried = True
while not signal.released:
await asyncio.sleep(0.001)
await res.consume()

def connection_count(driver_):
return sum(len(v) for v in driver_._pool.connections.values())

driver = neo4j.AsyncGraphDatabase.driver(uri, auth=auth)
await fill_pool(driver)
# sanity check, there should be some connections
assert connection_count(driver) >= 10

# start the close and give it some event loop iterations to kick off
fut = asyncio.ensure_future(driver.close())
await asyncio.sleep(0)

# cancel in the middle of closing connections
fut.cancel()
# give the driver a chance to close connections forcefully
await asyncio.sleep(0)
# driver should be marked as closed to not emmit a ResourceWarning later
assert driver._closed == True