Skip to content

Commit debf8a4

Browse files
committed
Clean-up and polish
1 parent 232f0e6 commit debf8a4

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+794
-77
lines changed

.gitattributes

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# configure github not to display generated files
22
/src/neo4j/_sync/** linguist-generated=true
3-
/tests/unit/sync_/** linguist-generated=true
4-
/tests/integration/sync_/** linguist-generated=true
3+
/tests/unit/sync/** linguist-generated=true
4+
/tests/integration/sync/** linguist-generated=true
55
/testkitbackend/_sync/** linguist-generated=true

docs/source/api.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ Closing a driver will immediately shut down all connections in the pool.
165165
.. autoclass:: neo4j.Driver()
166166
:members: session, execute_query_bookmark_manager, encrypted, close,
167167
verify_connectivity, get_server_info, verify_authentication,
168-
supports_session_auth, supports_multi_db, force_home_database_resolution
168+
supports_session_auth, supports_multi_db
169169

170170
.. method:: execute_query(query, parameters_=None,routing_=neo4j.RoutingControl.WRITE, database_=None, impersonated_user_=None, bookmark_manager_=self.execute_query_bookmark_manager, auth_=None, result_transformer_=Result.to_eager_result, **kwargs)
171171

src/neo4j/_async/home_db_cache.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,11 @@ def __init__(
6060
f"got {max_size}"
6161
)
6262
self._max_size = max_size
63+
self._truncate_size = (
64+
min(max_size, int(0.01 * max_size * math.log(max_size)))
65+
if max_size is not None
66+
else None
67+
)
6368

6469
def compute_key(
6570
self,
@@ -106,7 +111,9 @@ def _clean(self, now: float | None = None) -> None:
106111
now = monotonic() if now is None else now
107112
if now - self._oldest_entry > self._ttl:
108113
self._cache = {
109-
k: v for k, v in self._cache.items() if now - v[0] < self._ttl
114+
k: v
115+
for k, v in self._cache.items()
116+
if now - v[0] < self._ttl * 0.9
110117
}
111118
self._oldest_entry = min(
112119
(v[0] for v in self._cache.values()), default=now
@@ -117,7 +124,7 @@ def _clean(self, now: float | None = None) -> None:
117124
self._cache.items(),
118125
key=lambda item: item[1][0],
119126
reverse=True,
120-
)[: int(self._max_size * 0.9)]
127+
)[: self._truncate_size]
121128
)
122129

123130
def __len__(self) -> int:

src/neo4j/_async/io/_pool.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,13 @@ def add_connection(self, connection):
105105
def remove_connection(self, connection):
106106
if self.feature_check(connection):
107107
if self.with_feature == 0:
108-
raise ValueError(
108+
raise RuntimeError(
109109
"No connections to be removed from feature tracker"
110110
)
111111
self.with_feature -= 1
112112
else:
113113
if self.without_feature == 0:
114-
raise ValueError(
114+
raise RuntimeError(
115115
"No connections to be removed from feature tracker"
116116
)
117117
self.without_feature -= 1
@@ -143,7 +143,8 @@ def is_direct_pool(self) -> bool: ...
143143

144144
@property
145145
def ssr_enabled(self) -> bool:
146-
return self._ssr_feature_tracker.has_feature
146+
with self.lock:
147+
return self._ssr_feature_tracker.has_feature
147148

148149
async def __aenter__(self):
149150
return self
@@ -601,8 +602,8 @@ async def close(self):
601602
for address in list(self.connections)
602603
for connection in self.connections.pop(address, ())
603604
]
604-
for connection in connections:
605-
self._ssr_feature_tracker.remove_connection(connection)
605+
for connection in connections:
606+
self._ssr_feature_tracker.remove_connection(connection)
606607
await self._close_connections(connections)
607608
except TypeError:
608609
pass
@@ -1012,7 +1013,7 @@ async def update_routing_table(
10121013
log.error("Unable to retrieve routing information")
10131014
raise ServiceUnavailable("Unable to retrieve routing information")
10141015

1015-
async def update_connection_pool(self, *, database):
1016+
async def update_connection_pool(self):
10161017
async with self.refresh_lock:
10171018
routing_tables = list(self.routing_tables.values())
10181019

@@ -1077,12 +1078,14 @@ async def ensure_routing_table_is_fresh(
10771078
)
10781079
return False
10791080

1081+
database_request = database.name if not database.guessed else None
1082+
10801083
async def wrapped_database_callback(database: str | None) -> None:
10811084
await AsyncUtil.callback(database_callback, database)
1082-
await self.update_connection_pool(database=database)
1085+
await self.update_connection_pool()
10831086

10841087
await self.update_routing_table(
1085-
database=database.name if not database.guessed else None,
1088+
database=database_request,
10861089
imp_user=imp_user,
10871090
bookmarks=bookmarks,
10881091
auth=auth,

src/neo4j/_async/work/workspace.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,14 +104,15 @@ async def __aenter__(self) -> AsyncWorkspace:
104104
async def __aexit__(self, exc_type, exc_value, traceback):
105105
await self.close()
106106

107-
def _make_db_resolution_callback(self) -> t.Callable[[str], None] | None:
107+
def _make_db_resolution_callback(
108+
self,
109+
) -> t.Callable[[str | None], None] | None:
108110
if self._pinned_database:
109111
return None
110112

111113
def _database_callback(database: str | None) -> None:
112-
if not self._pinned_database:
113-
self._set_pinned_database(database)
114-
if self._last_cache_key is None:
114+
self._set_pinned_database(database)
115+
if self._last_cache_key is None or database is None:
115116
return
116117
db_cache: AsyncHomeDbCache = self._pool.home_db_cache
117118
db_cache.set(self._last_cache_key, database)
@@ -206,7 +207,7 @@ async def _connect(self, access_mode, auth=None, **acquire_kwargs) -> None:
206207
not self._pool.ssr_enabled or not self._connection.ssr_enabled
207208
)
208209
):
209-
# race condition: in the meantime, the pool added a connection,
210+
# race condition: in the meantime the pool added a connection
210211
# which does not support SSR.
211212
# => we need to fall back to explicit home database resolution
212213
log.debug(

src/neo4j/_sync/home_db_cache.py

Lines changed: 9 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/neo4j/_sync/io/_pool.py

Lines changed: 11 additions & 8 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/neo4j/_sync/work/workspace.py

Lines changed: 6 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/_async_compat/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# limitations under the License.
1515

1616

17+
from functools import wraps as _wraps
18+
1719
from .mark_decorator import (
1820
AsyncTestDecorators,
1921
mark_async_test,
@@ -27,4 +29,15 @@
2729
"TestDecorators",
2830
"mark_async_test",
2931
"mark_sync_test",
32+
"wrap_async",
3033
]
34+
35+
36+
def wrap_async(func):
37+
@_wraps(func)
38+
async def wrapper(*args, **kwargs): # noqa: RUF029
39+
# [noqa] the hole point of this wrapper is to turn a sync function into
40+
# an async one for testing purposes
41+
return func(*args, **kwargs)
42+
43+
return wrapper

tests/unit/async_/io/test_class_bolt3.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -569,3 +569,25 @@ def on_success(metadata):
569569
await connection.fetch_all()
570570

571571
assert received_metadata == sent_metadata
572+
573+
574+
@pytest.mark.parametrize("ssr_hint", (True, False, None))
575+
@mark_async_test
576+
async def test_ssr_enabled(ssr_hint, fake_socket_pair):
577+
address = neo4j.Address(("127.0.0.1", 7687))
578+
sockets = fake_socket_pair(
579+
address,
580+
packer_cls=AsyncBolt3.PACKER_CLS,
581+
unpacker_cls=AsyncBolt3.UNPACKER_CLS,
582+
)
583+
meta = {"server": "Neo4j/4.3.4"}
584+
if ssr_hint is not None:
585+
meta["hints"] = {"ssr.enabled": ssr_hint}
586+
await sockets.server.send_message(b"\x70", meta)
587+
await sockets.server.send_message(b"\x70", {})
588+
connection = AsyncBolt3(
589+
address, sockets.client, AsyncPoolConfig.max_connection_lifetime
590+
)
591+
assert connection.ssr_enabled is False
592+
await connection.hello()
593+
assert connection.ssr_enabled is False

0 commit comments

Comments
 (0)