From 2fd52e5c01faaadfb04552c56b6b78f9a24a2d56 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Fri, 28 Apr 2023 07:51:25 +0200 Subject: [PATCH 01/26] Rebase of branch 'home-db-cache' without config option --- docs/source/api.rst | 46 +++++++++--- docs/source/async_api.rst | 6 +- src/neo4j/_async/driver.py | 28 ++++++- src/neo4j/_async/home_db_cache.py | 117 +++++++++++++++++++++++++++++ src/neo4j/_async/io/_bolt.py | 23 +----- src/neo4j/_async/io/_pool.py | 5 +- src/neo4j/_async/work/workspace.py | 103 ++++++++++++++++++++----- src/neo4j/_auth_management.py | 18 +++++ src/neo4j/_sync/driver.py | 28 ++++++- src/neo4j/_sync/home_db_cache.py | 117 +++++++++++++++++++++++++++++ src/neo4j/_sync/io/_bolt.py | 23 +----- src/neo4j/_sync/io/_pool.py | 5 +- src/neo4j/_sync/work/workspace.py | 100 +++++++++++++++++++----- testkitbackend/test_config.json | 1 + 14 files changed, 530 insertions(+), 90 deletions(-) create mode 100644 src/neo4j/_async/home_db_cache.py create mode 100644 src/neo4j/_sync/home_db_cache.py diff --git a/docs/source/api.rst b/docs/source/api.rst index 35041a85..fc618e33 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -165,7 +165,7 @@ Closing a driver will immediately shut down all connections in the pool. .. autoclass:: neo4j.Driver() :members: session, execute_query_bookmark_manager, encrypted, close, verify_connectivity, get_server_info, verify_authentication, - supports_session_auth, supports_multi_db + supports_session_auth, supports_multi_db, force_home_database_resolution .. method:: execute_query(query, parameters_=None,routing_=neo4j.RoutingControl.WRITE, database_=None, impersonated_user_=None, bookmark_manager_=self.execute_query_bookmark_manager, auth_=None, result_transformer_=Result.to_eager_result, **kwargs) @@ -260,7 +260,11 @@ Closing a driver will immediately shut down all connections in the pool. :param database\_: Database to execute the query against. - None (default) uses the database configured on the server side. + :data:`None` (default) uses the database configured on the server + side. + Depending on the :ref:`max-home-database-delay-ref` configuration, + propagation of changes to the server side default might not be + immediate. .. Note:: It is recommended to always specify the database explicitly @@ -399,6 +403,7 @@ Additional configuration can be provided via the :class:`neo4j.Driver` construct + :ref:`liveness-check-timeout-ref` + :ref:`max-connection-pool-size-ref` + :ref:`max-transaction-retry-time-ref` ++ :ref:`max-home-database-delay-ref` + :ref:`resolver-ref` + :ref:`trust-ref` + :ref:`ssl-context-ref` @@ -521,6 +526,26 @@ The maximum total number of connections allowed, per host (i.e. cluster nodes), :Default: ``30.0`` +.. _max-home-database-delay-ref: + +``max_home_database_delay`` +--------------------------- +Defines an upper bound for how long (in seconds) a resolved home database can be cached. + +Set this value to ``0`` to prohibit any caching. +This likely incurs a significant performance penalty (driver and server side). +Set this value to ``float("inf")`` to allow the driver to cache resolutions forever. + +Note that in future driver/protocol versions, this setting might have no effect. + +:Type: ``float`` +:Default: ``5.0`` + +.. versionadded:: 5.x + +.. seealso:: :meth:`Driver.force_home_database_resolution` + + .. 
_resolver-ref: ``resolver`` @@ -1035,13 +1060,16 @@ Specifically, the following applies: instance, if the user's home database name is 'movies' and the server supplies it to the driver upon database name fetching for the session, all queries within that session are executed with the explicit database - name 'movies' supplied. Any change to the user’s home database is - reflected only in sessions created after such change takes effect. This - behavior requires additional network communication. In clustered - environments, it is strongly recommended to avoid a single point of - failure. For instance, by ensuring that the connection URI resolves to - multiple endpoints. For older Bolt protocol versions the behavior is the - same as described for the **bolt schemes** above. + name 'movies' supplied. Changes to the user's home database will only be + picked up by future sessions. There might be an additional delay depending + on the :ref:`max-home-database-delay-ref` configuration. Resolving the + user's home database name requires additional network communication. + Therefore, it is recommended to either specify the database name + explicitly or set the home database delay appropriately. + In clustered environments, it is strongly recommended to avoid a single + point of failure. For instance, by ensuring that the connection URI + resolves to multiple endpoints. For older Bolt protocol versions the + behavior is the same as described for the **bolt schemes** above. .. code-block:: python diff --git a/docs/source/async_api.rst b/docs/source/async_api.rst index 483b70e7..8dc40730 100644 --- a/docs/source/async_api.rst +++ b/docs/source/async_api.rst @@ -247,7 +247,11 @@ Closing a driver will immediately shut down all connections in the pool. :param database\_: Database to execute the query against. - None (default) uses the database configured on the server side. + :data:`None` (default) uses the database configured on the server + side. + Depending on the :ref:`max-home-database-delay-ref` configuration, + propagation of changes to the server side default might not be + immediate. .. Note:: It is recommended to always specify the database explicitly diff --git a/src/neo4j/_async/driver.py b/src/neo4j/_async/driver.py index 627fa14d..6dac8697 100644 --- a/src/neo4j/_async/driver.py +++ b/src/neo4j/_async/driver.py @@ -801,7 +801,11 @@ async def example(driver: neo4j.AsyncDriver) -> int: :param database_: Database to execute the query against. - None (default) uses the database configured on the server side. + :data:`None` (default) uses the database configured on the server + side. + Depending on the :ref:`max-home-database-delay-ref` configuration, + propagation of changes to the server side default might not be + immediate. .. Note:: It is recommended to always specify the database explicitly @@ -1298,6 +1302,28 @@ async def _get_server_info(self, session_config) -> ServerInfo: async with self._session(session_config) as session: return await session._get_server_info() + def force_home_database_resolution(self) -> None: + """Force the driver to resolve all home databases (again). + + The resolution is lazy and will only happen when the driver needs to + know the home database. + In practice, this means that the driver will flush the cache + configured by `max_home_database_delay`. 
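For illustration, a minimal usage sketch of this method; the URI and credentials are placeholders, and the trivial queries stand in for real workloads:

.. code-block:: python

    import neo4j

    # Placeholder connection details for illustration only.
    URI = "neo4j://localhost:7687"
    AUTH = ("neo4j", "password")

    with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver:
        # No database configured: the driver resolves the user's home
        # database and may answer later sessions from its cache.
        with driver.session() as session:
            session.run("RETURN 1").consume()

        # Suppose the application has just changed the user's home
        # database. Flushing the cache makes the next session re-resolve
        # it instead of waiting for the cached entry to expire.
        driver.force_home_database_resolution()
        with driver.session() as session:
            session.run("RETURN 1").consume()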
+ + This method is useful, for instance, when an application has changed a + user's home database and wants the next session to pick up the change + without setting `max_home_database_delay` to `0` and paying the + resulting performance penalty. + + .. versionadded:: 5.x + + .. seealso:: + Driver config :ref:`max-home-database-delay-ref` + """ + home_db_cache = self._pool.home_db_cache + if home_db_cache.enabled: + home_db_cache.clear() + async def _work( tx: AsyncManagedTransaction, diff --git a/src/neo4j/_async/home_db_cache.py b/src/neo4j/_async/home_db_cache.py new file mode 100644 index 00000000..5e174094 --- /dev/null +++ b/src/neo4j/_async/home_db_cache.py @@ -0,0 +1,117 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +import heapq +import math +import typing as t +from time import monotonic + +from .._async_compat.concurrency import AsyncCooperativeLock + + +if t.TYPE_CHECKING: + # TAuthKey = t.Tuple[t.Tuple[]] + TKey = str | tuple[tuple[str, t.Hashable], ...] | None + TVal = tuple[float, str] + + +class AsyncHomeDbCache: + _ttl: float + _enabled: bool + _max_size: int | None + + def __init__( + self, + ttl: float = float("inf"), + enabled: bool = True, + max_size: int | None = None, + ) -> None: + if math.isnan(ttl) or ttl <= 0: + raise ValueError("home db cache ttl must be greater than 0") + self._enabled = enabled + self._ttl = ttl + self._cache: dict[TKey, TVal] = {} + self._lock = AsyncCooperativeLock() + self._last_clean = monotonic() + + def compute_key( + self, + imp_user: str | None, + auth: dict | None, + ) -> TKey: + if not self._enabled: + return None + if imp_user is not None: + return imp_user + if auth is not None: + return _hashable_dict(auth) + return None + + def get(self, key: TKey) -> str | None: + with self._lock: + val = self._cache.get(key) + if val is None: + return None + now = monotonic() + if now - val[0] > self._ttl: + del self._cache[key] + return None + # Saved some time with a cache hit, + # so we can waste some with cleaning the cache ;) + self._clean(now) + return val[1] + + def set(self, key: TKey, value: str | None) -> None: + with self._lock: + if value is None: + self._cache.pop(key, None) + else: + self._cache[key] = (monotonic(), value) + + def clear(self) -> None: + with self._lock: + self._cache = {} + self._last_clean = monotonic() + + def _clean(self, now: float) -> None: + if self._max_size is not None and len(self._cache) > self._max_size: + self._cache = dict( + heapq.nlargest( + self._max_size, + self._cache.items(), + key=lambda item: item[1][0], + ) + ) + if now - self._last_clean > self._ttl: + self._cache = { + k: v for k, v in self._cache.items() if now - v[0] < self._ttl + } + self._last_clean = now + + @property + def enabled(self) -> bool: + return self._enabled + + +def _hashable_dict(d: dict) -> tuple: + return tuple( + (k, _hashable_dict(v) if isinstance(v, dict) 
else v) + for k, v in sorted(d.items()) + ) diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index 339c065f..d973b2ac 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -25,6 +25,7 @@ from ..._async_compat.network import AsyncBoltSocket from ..._async_compat.util import AsyncUtil +from ..._auth_management import to_auth_dict from ..._codec.hydration import ( HydrationHandlerABC, v1 as hydration_v1, @@ -39,12 +40,10 @@ from ..._sync.config import PoolConfig from ...addressing import ResolvedAddress from ...api import ( - Auth, ServerInfo, Version, ) from ...exceptions import ( - AuthError, ConfigurationError, DriverError, IncompleteCommit, @@ -187,7 +186,7 @@ def __init__( self.user_agent = USER_AGENT self.auth = auth - self.auth_dict = self._to_auth_dict(auth) + self.auth_dict = to_auth_dict(auth) self.auth_manager = auth_manager self.telemetry_disabled = telemetry_disabled @@ -206,22 +205,6 @@ def _get_server_state_manager(self) -> ServerStateManagerBase: ... @abc.abstractmethod def _get_client_state_manager(self) -> ClientStateManagerBase: ... - @classmethod - def _to_auth_dict(cls, auth): - # Determine auth details - if not auth: - return {} - elif isinstance(auth, tuple) and 2 <= len(auth) <= 3: - return vars(Auth("basic", *auth)) - else: - try: - return vars(auth) - except (KeyError, TypeError) as e: - # TODO: 6.0 - change this to be a DriverError (or subclass) - raise AuthError( - f"Cannot determine auth details from {auth!r}" - ) from e - @property def connection_id(self): return self.server_info._metadata.get("connection_id", "") @@ -626,7 +609,7 @@ def re_auth( :returns: whether the auth was changed """ - new_auth_dict = self._to_auth_dict(auth) + new_auth_dict = to_auth_dict(auth) if not force and new_auth_dict == self.auth_dict: self.auth_manager = auth_manager self.auth = auth diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index 7a520abe..38e5900b 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -59,6 +59,7 @@ WriteServiceUnavailable, ) from ..config import AsyncPoolConfig +from ..home_db_cache import AsyncHomeDbCache from ._bolt import AsyncBolt @@ -94,6 +95,7 @@ def __init__(self, opener, pool_config, workspace_config): self.connections_reservations = defaultdict(lambda: 0) self.lock = AsyncCooperativeRLock() self.cond = AsyncCondition(self.lock) + self.home_db_cache = AsyncHomeDbCache(max_size=10_000) @property @abc.abstractmethod @@ -853,8 +855,7 @@ async def _update_routing_table_from( address, self.routing_tables[new_database], ) - if callable(database_callback): - database_callback(new_database) + await AsyncUtil.callback(database_callback, new_database) return True await self.deactivate(router) return False diff --git a/src/neo4j/_async/work/workspace.py b/src/neo4j/_async/work/workspace.py index dc044249..7424a3cc 100644 --- a/src/neo4j/_async/work/workspace.py +++ b/src/neo4j/_async/work/workspace.py @@ -17,8 +17,10 @@ from __future__ import annotations import logging +import typing as t from ..._async_compat.util import AsyncUtil +from ..._auth_management import to_auth_dict from ..._conf import WorkspaceConfig from ..._meta import ( deprecation_warn, @@ -37,6 +39,20 @@ ) +if t.TYPE_CHECKING: + from ...api import _TAuth + from ...auth_management import ( + AsyncAuthManager, + AuthManager, + ) + from ..home_db_cache import ( + AsyncHomeDbCache, + TKey, + ) +else: + _TAuth = t.Any + + log = logging.getLogger("neo4j") @@ -87,6 +103,18 @@ async def 
__aenter__(self) -> AsyncWorkspace: async def __aexit__(self, exc_type, exc_value, traceback): await self.close() + def _make_database_callback( + self, + cache_key: TKey, + ) -> t.Callable[[str], None]: + def _database_callback(database: str | None) -> None: + db_cache: AsyncHomeDbCache = self._pool.home_db_cache + if db_cache.enabled: + db_cache.set(cache_key, database) + self._set_cached_database(database) + + return _database_callback + def _set_cached_database(self, database): self._cached_database = True self._config.database = database @@ -140,10 +168,8 @@ async def _update_bookmark(self, bookmark): async def _connect(self, access_mode, auth=None, **acquire_kwargs): acquisition_timeout = self._config.connection_acquisition_timeout - auth = AcquireAuth( - auth, - force_auth=acquire_kwargs.pop("force_auth", False), - ) + force_auth = acquire_kwargs.pop("force_auth", False) + acquire_auth = AcquireAuth(auth, force_auth=force_auth) if self._connection: # TODO: Investigate this @@ -151,6 +177,22 @@ async def _connect(self, access_mode, auth=None, **acquire_kwargs): await self._connection.send_all() await self._connection.fetch_all() await self._disconnect() + await self._fill_cached_database(acquire_auth) + acquire_kwargs_ = { + "access_mode": access_mode, + "timeout": acquisition_timeout, + "database": self._config.database, + "bookmarks": await self._get_bookmarks(), + "auth": acquire_auth, + "liveness_check_timeout": None, + } + acquire_kwargs_.update(acquire_kwargs) + self._connection = await self._pool.acquire(**acquire_kwargs_) + self._connection_access_mode = access_mode + + async def _fill_cached_database(self, acquire_auth: AcquireAuth) -> None: + auth = acquire_auth.auth + acquisition_timeout = self._config.connection_acquisition_timeout if not self._cached_database: if self._config.database is not None or not isinstance( self._pool, AsyncNeo4jPool @@ -163,26 +205,53 @@ async def _connect(self, access_mode, auth=None, **acquire_kwargs): # to try to fetch the home database. If provided by the server, # we shall use this database explicitly for all subsequent # actions within this session. 
+ # Unless we have the resolved home db in our cache: + + db_cache: AsyncHomeDbCache = self._pool.home_db_cache + cache_key = cached_db = None + if db_cache.enabled: + cache_key = db_cache.compute_key( + self._config.impersonated_user, + await self._resolve_session_auth(auth), + ) + cached_db = db_cache.get(cache_key) + if cached_db is not None: + log.debug( + ( + "[#0000] _: <WORKSPACE> resolved home database " + "from cache: %s" + ), + cached_db, + ) + self._set_cached_database(cached_db) + return log.debug("[#0000] _: <WORKSPACE> resolve home database") await self._pool.update_routing_table( database=self._config.database, imp_user=self._config.impersonated_user, bookmarks=await self._get_bookmarks(), - auth=auth, + auth=acquire_auth, acquisition_timeout=acquisition_timeout, - database_callback=self._set_cached_database, + database_callback=self._make_database_callback(cache_key), ) - acquire_kwargs_ = { - "access_mode": access_mode, - "timeout": acquisition_timeout, - "database": self._config.database, - "bookmarks": await self._get_bookmarks(), - "auth": auth, - "liveness_check_timeout": None, - } - acquire_kwargs_.update(acquire_kwargs) - self._connection = await self._pool.acquire(**acquire_kwargs_) - self._connection_access_mode = access_mode + + @staticmethod + async def _resolve_session_auth( + auth: AsyncAuthManager | AuthManager | None, + ) -> dict | None: + if auth is None: + return None + # resolved_auth = await AsyncUtil.callback(auth.get_auth) + # The above line breaks mypy + # https://github.com/python/mypy/issues/15295 + auth_getter: t.Callable[[], _TAuth | t.Awaitable[_TAuth]] = ( + auth.get_auth + ) + # so we enforce the right type here + # (explicit type annotation above added as it's a necessary assumption + # for this cast to be correct) + resolved_auth = t.cast(_TAuth, await AsyncUtil.callback(auth_getter)) + return to_auth_dict(resolved_auth) async def _disconnect(self, sync=False): if self._connection: diff --git a/src/neo4j/_auth_management.py b/src/neo4j/_auth_management.py index 409d9ba9..46ddf901 100644 --- a/src/neo4j/_auth_management.py +++ b/src/neo4j/_auth_management.py @@ -22,6 +22,8 @@ from dataclasses import dataclass from ._meta import preview +from .api import Auth +from .exceptions import AuthError if t.TYPE_CHECKING: @@ -321,3 +323,19 @@ async def get_certificate(self) -> ClientCertificate | None: .. seealso:: :meth:`.ClientCertificateProvider.get_certificate` """ ... + + +def to_auth_dict(auth: _TAuth) -> dict[str, t.Any]: + # Determine auth details + if not auth: + return {} + elif isinstance(auth, tuple) and 2 <= len(auth) <= 3: + return vars(Auth("basic", *auth)) + else: + try: + return vars(auth) + except (KeyError, TypeError) as e: + # TODO: 6.0 - change this to be a DriverError (or subclass) + raise AuthError( + f"Cannot determine auth details from {auth!r}" + ) from e diff --git a/src/neo4j/_sync/driver.py b/src/neo4j/_sync/driver.py index 3b205a34..af947f6e 100644 --- a/src/neo4j/_sync/driver.py +++ b/src/neo4j/_sync/driver.py @@ -800,7 +800,11 @@ def example(driver: neo4j.Driver) -> int: :param database_: Database to execute the query against. - None (default) uses the database configured on the server side. + :data:`None` (default) uses the database configured on the server + side. + Depending on the :ref:`max-home-database-delay-ref` configuration, + propagation of changes to the server side default might not be + immediate. .. 
Note:: It is recommended to always specify the database explicitly @@ -1297,6 +1301,28 @@ def _get_server_info(self, session_config) -> ServerInfo: with self._session(session_config) as session: return session._get_server_info() + def force_home_database_resolution(self) -> None: + """Force the driver to resolve all home databases (again). + + The resolution is lazy and will only happen when the driver needs to + know the home database. + In practice, this means that the driver will flush the cache + configured by `max_home_database_delay`. + + This method is useful, for instance, when an application has changed a + user's home database and wants the next session to pick up the change + without setting `max_home_database_delay` to `0` and paying the + resulting performance penalty. + + .. versionadded:: 5.x + + .. seealso:: + Driver config :ref:`max-home-database-delay-ref` + """ + home_db_cache = self._pool.home_db_cache + if home_db_cache.enabled: + home_db_cache.clear() + def _work( tx: ManagedTransaction, diff --git a/src/neo4j/_sync/home_db_cache.py b/src/neo4j/_sync/home_db_cache.py new file mode 100644 index 00000000..8f8dd21f --- /dev/null +++ b/src/neo4j/_sync/home_db_cache.py @@ -0,0 +1,117 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +import heapq +import math +import typing as t +from time import monotonic + +from .._async_compat.concurrency import CooperativeLock + + +if t.TYPE_CHECKING: + # TAuthKey = t.Tuple[t.Tuple[]] + TKey = str | tuple[tuple[str, t.Hashable], ...] 
| None + TVal = tuple[float, str] + + +class HomeDbCache: + _ttl: float + _enabled: bool + _max_size: int | None + + def __init__( + self, + ttl: float = float("inf"), + enabled: bool = True, + max_size: int | None = None, + ) -> None: + if math.isnan(ttl) or ttl <= 0: + raise ValueError("home db cache ttl must be greater than 0") + self._enabled = enabled + self._ttl = ttl + self._cache: dict[TKey, TVal] = {} + self._lock = CooperativeLock() + self._last_clean = monotonic() + + def compute_key( + self, + imp_user: str | None, + auth: dict | None, + ) -> TKey: + if not self._enabled: + return None + if imp_user is not None: + return imp_user + if auth is not None: + return _hashable_dict(auth) + return None + + def get(self, key: TKey) -> str | None: + with self._lock: + val = self._cache.get(key) + if val is None: + return None + now = monotonic() + if now - val[0] > self._ttl: + del self._cache[key] + return None + # Saved some time with a cache hit, + # so we can waste some with cleaning the cache ;) + self._clean(now) + return val[1] + + def set(self, key: TKey, value: str | None) -> None: + with self._lock: + if value is None: + self._cache.pop(key, None) + else: + self._cache[key] = (monotonic(), value) + + def clear(self) -> None: + with self._lock: + self._cache = {} + self._last_clean = monotonic() + + def _clean(self, now: float) -> None: + if self._max_size is not None and len(self._cache) > self._max_size: + self._cache = dict( + heapq.nlargest( + self._max_size, + self._cache.items(), + key=lambda item: item[1][0], + ) + ) + if now - self._last_clean > self._ttl: + self._cache = { + k: v for k, v in self._cache.items() if now - v[0] < self._ttl + } + self._last_clean = now + + @property + def enabled(self) -> bool: + return self._enabled + + +def _hashable_dict(d: dict) -> tuple: + return tuple( + (k, _hashable_dict(v) if isinstance(v, dict) else v) + for k, v in sorted(d.items()) + ) diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index f1176ba0..137aa428 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -25,6 +25,7 @@ from ..._async_compat.network import BoltSocket from ..._async_compat.util import Util +from ..._auth_management import to_auth_dict from ..._codec.hydration import ( HydrationHandlerABC, v1 as hydration_v1, @@ -39,12 +40,10 @@ from ..._sync.config import PoolConfig from ...addressing import ResolvedAddress from ...api import ( - Auth, ServerInfo, Version, ) from ...exceptions import ( - AuthError, ConfigurationError, DriverError, IncompleteCommit, @@ -187,7 +186,7 @@ def __init__( self.user_agent = USER_AGENT self.auth = auth - self.auth_dict = self._to_auth_dict(auth) + self.auth_dict = to_auth_dict(auth) self.auth_manager = auth_manager self.telemetry_disabled = telemetry_disabled @@ -206,22 +205,6 @@ def _get_server_state_manager(self) -> ServerStateManagerBase: ... @abc.abstractmethod def _get_client_state_manager(self) -> ClientStateManagerBase: ... 
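The sync cache above mirrors the async one; its intended call pattern is compute-key, then get, then set. A standalone sketch against the private module this patch adds, with an illustrative user and database name:

.. code-block:: python

    from neo4j._sync.home_db_cache import HomeDbCache

    cache = HomeDbCache(ttl=5.0)

    # Keys derive from the impersonated user if set, else from the
    # resolved auth; either identifies whose home database is cached.
    key = cache.compute_key(imp_user=None, auth={"principal": "alice"})

    if cache.get(key) is None:        # miss: caller resolves via routing
        cache.set(key, "alicedb")     # store the server-resolved name

    assert cache.get(key) == "alicedb"  # hit until the 5 s TTL expires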
- @classmethod - def _to_auth_dict(cls, auth): - # Determine auth details - if not auth: - return {} - elif isinstance(auth, tuple) and 2 <= len(auth) <= 3: - return vars(Auth("basic", *auth)) - else: - try: - return vars(auth) - except (KeyError, TypeError) as e: - # TODO: 6.0 - change this to be a DriverError (or subclass) - raise AuthError( - f"Cannot determine auth details from {auth!r}" - ) from e - @property def connection_id(self): return self.server_info._metadata.get("connection_id", "") @@ -626,7 +609,7 @@ def re_auth( :returns: whether the auth was changed """ - new_auth_dict = self._to_auth_dict(auth) + new_auth_dict = to_auth_dict(auth) if not force and new_auth_dict == self.auth_dict: self.auth_manager = auth_manager self.auth = auth diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py index 1570e745..8ea85586 100644 --- a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -59,6 +59,7 @@ WriteServiceUnavailable, ) from ..config import PoolConfig +from ..home_db_cache import HomeDbCache from ._bolt import Bolt @@ -91,6 +92,7 @@ def __init__(self, opener, pool_config, workspace_config): self.connections_reservations = defaultdict(lambda: 0) self.lock = CooperativeRLock() self.cond = Condition(self.lock) + self.home_db_cache = HomeDbCache(max_size=10_000) @property @abc.abstractmethod @@ -850,8 +852,7 @@ def _update_routing_table_from( address, self.routing_tables[new_database], ) - if callable(database_callback): - database_callback(new_database) + Util.callback(database_callback, new_database) return True self.deactivate(router) return False diff --git a/src/neo4j/_sync/work/workspace.py b/src/neo4j/_sync/work/workspace.py index 55ca883d..7a30cb65 100644 --- a/src/neo4j/_sync/work/workspace.py +++ b/src/neo4j/_sync/work/workspace.py @@ -17,8 +17,10 @@ from __future__ import annotations import logging +import typing as t from ..._async_compat.util import Util +from ..._auth_management import to_auth_dict from ..._conf import WorkspaceConfig from ..._meta import ( deprecation_warn, @@ -37,6 +39,17 @@ ) +if t.TYPE_CHECKING: + from ...api import _TAuth + from ...auth_management import AuthManager + from ..home_db_cache import ( + HomeDbCache, + TKey, + ) +else: + _TAuth = t.Any + + log = logging.getLogger("neo4j") @@ -87,6 +100,18 @@ def __enter__(self) -> Workspace: def __exit__(self, exc_type, exc_value, traceback): self.close() + def _make_database_callback( + self, + cache_key: TKey, + ) -> t.Callable[[str], None]: + def _database_callback(database: str | None) -> None: + db_cache: HomeDbCache = self._pool.home_db_cache + if db_cache.enabled: + db_cache.set(cache_key, database) + self._set_cached_database(database) + + return _database_callback + def _set_cached_database(self, database): self._cached_database = True self._config.database = database @@ -140,10 +165,8 @@ def _update_bookmark(self, bookmark): def _connect(self, access_mode, auth=None, **acquire_kwargs): acquisition_timeout = self._config.connection_acquisition_timeout - auth = AcquireAuth( - auth, - force_auth=acquire_kwargs.pop("force_auth", False), - ) + force_auth = acquire_kwargs.pop("force_auth", False) + acquire_auth = AcquireAuth(auth, force_auth=force_auth) if self._connection: # TODO: Investigate this @@ -151,6 +174,22 @@ def _connect(self, access_mode, auth=None, **acquire_kwargs): self._connection.send_all() self._connection.fetch_all() self._disconnect() + self._fill_cached_database(acquire_auth) + acquire_kwargs_ = { + "access_mode": access_mode, + "timeout": 
acquisition_timeout, + "database": self._config.database, + "bookmarks": self._get_bookmarks(), + "auth": acquire_auth, + "liveness_check_timeout": None, + } + acquire_kwargs_.update(acquire_kwargs) + self._connection = self._pool.acquire(**acquire_kwargs_) + self._connection_access_mode = access_mode + + def _fill_cached_database(self, acquire_auth: AcquireAuth) -> None: + auth = acquire_auth.auth + acquisition_timeout = self._config.connection_acquisition_timeout if not self._cached_database: if self._config.database is not None or not isinstance( self._pool, Neo4jPool @@ -163,26 +202,53 @@ # to try to fetch the home database. If provided by the server, # we shall use this database explicitly for all subsequent # actions within this session. + # Unless we have the resolved home db in our cache: + + db_cache: HomeDbCache = self._pool.home_db_cache + cache_key = cached_db = None + if db_cache.enabled: + cache_key = db_cache.compute_key( + self._config.impersonated_user, + self._resolve_session_auth(auth), + ) + cached_db = db_cache.get(cache_key) + if cached_db is not None: + log.debug( + ( + "[#0000] _: <WORKSPACE> resolved home database " + "from cache: %s" + ), + cached_db, + ) + self._set_cached_database(cached_db) + return log.debug("[#0000] _: <WORKSPACE> resolve home database") self._pool.update_routing_table( database=self._config.database, imp_user=self._config.impersonated_user, bookmarks=self._get_bookmarks(), - auth=auth, + auth=acquire_auth, acquisition_timeout=acquisition_timeout, - database_callback=self._set_cached_database, + database_callback=self._make_database_callback(cache_key), ) - acquire_kwargs_ = { - "access_mode": access_mode, - "timeout": acquisition_timeout, - "database": self._config.database, - "bookmarks": self._get_bookmarks(), - "auth": auth, - "liveness_check_timeout": None, - } - acquire_kwargs_.update(acquire_kwargs) - self._connection = self._pool.acquire(**acquire_kwargs_) - self._connection_access_mode = access_mode + + @staticmethod + def _resolve_session_auth( + auth: AuthManager | AuthManager | None, + ) -> dict | None: + if auth is None: + return None + # resolved_auth = await AsyncUtil.callback(auth.get_auth) + # The above line breaks mypy + # https://github.com/python/mypy/issues/15295 + auth_getter: t.Callable[[], _TAuth | t.Union[_TAuth]] = ( + auth.get_auth + ) + # so we enforce the right type here + # (explicit type annotation above added as it's a necessary assumption + # for this cast to be correct) + resolved_auth = t.cast(_TAuth, Util.callback(auth_getter)) + return to_auth_dict(resolved_auth) def _disconnect(self, sync=False): if self._connection: diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index bca7f0ca..b4b357a6 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -61,6 +61,7 @@ "Feature:Bolt:5.7": true, "Feature:Bolt:Patch:UTC": true, "Feature:Impersonation": true, + "Feature:HomeDbCache": true, "Feature:TLS:1.1": "Driver blocks TLS 1.1 for security reasons.", "Feature:TLS:1.2": true, "Feature:TLS:1.3": "Depends on the machine (will be calculated dynamically).", From 4153be60977ff175632f93ef6a62c1f54b3eadeb Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Thu, 31 Oct 2024 08:47:57 +0100 Subject: [PATCH 02/26] WIP --- src/neo4j/_async/home_db_cache.py | 17 +- src/neo4j/_async/io/_bolt.py | 16 +- src/neo4j/_async/io/_bolt3.py | 2 + src/neo4j/_async/io/_bolt4.py | 8 +- src/neo4j/_async/io/_bolt5.py | 32 +- 
src/neo4j/_async/io/_pool.py | 59 +- src/neo4j/_async/work/result.py | 7 +- src/neo4j/_async/work/session.py | 2 + src/neo4j/_async/work/transaction.py | 9 + src/neo4j/_async/work/workspace.py | 145 ++-- src/neo4j/_sync/home_db_cache.py | 17 +- src/neo4j/_sync/io/_bolt.py | 16 +- src/neo4j/_sync/io/_bolt3.py | 2 + src/neo4j/_sync/io/_bolt4.py | 8 +- src/neo4j/_sync/io/_bolt5.py | 32 +- src/neo4j/_sync/io/_pool.py | 59 +- src/neo4j/_sync/work/result.py | 5 + src/neo4j/_sync/work/session.py | 2 + src/neo4j/_sync/work/transaction.py | 9 + src/neo4j/_sync/work/workspace.py | 145 ++-- tests/unit/async_/io/test_class_bolt.py | 10 +- tests/unit/async_/io/test_class_bolt3.py | 2 +- tests/unit/async_/io/test_class_bolt4x0.py | 2 +- tests/unit/async_/io/test_class_bolt4x1.py | 2 +- tests/unit/async_/io/test_class_bolt4x2.py | 2 +- tests/unit/async_/io/test_class_bolt4x3.py | 2 +- tests/unit/async_/io/test_class_bolt4x4.py | 2 +- tests/unit/async_/io/test_class_bolt5x0.py | 2 +- tests/unit/async_/io/test_class_bolt5x1.py | 2 +- tests/unit/async_/io/test_class_bolt5x2.py | 2 +- tests/unit/async_/io/test_class_bolt5x3.py | 2 +- tests/unit/async_/io/test_class_bolt5x4.py | 2 +- tests/unit/async_/io/test_class_bolt5x5.py | 2 +- tests/unit/async_/io/test_class_bolt5x6.py | 2 +- tests/unit/async_/io/test_class_bolt5x7.py | 2 +- tests/unit/async_/io/test_class_bolt5x8.py | 850 +++++++++++++++++++++ tests/unit/async_/work/test_result.py | 6 +- tests/unit/async_/work/test_transaction.py | 4 +- tests/unit/common/work/test_summary.py | 2 + tests/unit/sync/io/test_class_bolt.py | 10 +- tests/unit/sync/io/test_class_bolt3.py | 2 +- tests/unit/sync/io/test_class_bolt4x0.py | 2 +- tests/unit/sync/io/test_class_bolt4x1.py | 2 +- tests/unit/sync/io/test_class_bolt4x2.py | 2 +- tests/unit/sync/io/test_class_bolt4x3.py | 2 +- tests/unit/sync/io/test_class_bolt4x4.py | 2 +- tests/unit/sync/io/test_class_bolt5x0.py | 2 +- tests/unit/sync/io/test_class_bolt5x1.py | 2 +- tests/unit/sync/io/test_class_bolt5x2.py | 2 +- tests/unit/sync/io/test_class_bolt5x3.py | 2 +- tests/unit/sync/io/test_class_bolt5x4.py | 2 +- tests/unit/sync/io/test_class_bolt5x5.py | 2 +- tests/unit/sync/io/test_class_bolt5x6.py | 2 +- tests/unit/sync/io/test_class_bolt5x7.py | 2 +- tests/unit/sync/io/test_class_bolt5x8.py | 850 +++++++++++++++++++++ tests/unit/sync/work/test_result.py | 6 +- tests/unit/sync/work/test_transaction.py | 4 +- 57 files changed, 2191 insertions(+), 199 deletions(-) create mode 100644 tests/unit/async_/io/test_class_bolt5x8.py create mode 100644 tests/unit/sync/io/test_class_bolt5x8.py diff --git a/src/neo4j/_async/home_db_cache.py b/src/neo4j/_async/home_db_cache.py index 5e174094..63b6ef16 100644 --- a/src/neo4j/_async/home_db_cache.py +++ b/src/neo4j/_async/home_db_cache.py @@ -28,7 +28,7 @@ if t.TYPE_CHECKING: # TAuthKey = t.Tuple[t.Tuple[]] - TKey = str | tuple[tuple[str, t.Hashable], ...] | None + TKey = str | tuple[tuple[str, t.Hashable], ...] 
| tuple[None] TVal = tuple[float, str] @@ -39,17 +39,18 @@ class AsyncHomeDbCache: def __init__( self, - ttl: float = float("inf"), enabled: bool = True, + ttl: float = float("inf"), max_size: int | None = None, ) -> None: if math.isnan(ttl) or ttl <= 0: - raise ValueError("home db cache ttl must be greater than 0") + raise ValueError(f"home db cache ttl must be greater than 0, got {ttl}") self._enabled = enabled self._ttl = ttl self._cache: dict[TKey, TVal] = {} self._lock = AsyncCooperativeLock() self._last_clean = monotonic() + self._max_size = max_size def compute_key( self, @@ -57,14 +58,16 @@ def compute_key( auth: dict | None, ) -> TKey: if not self._enabled: - return None + return (None,) if imp_user is not None: return imp_user if auth is not None: return _hashable_dict(auth) - return None + return (None,) def get(self, key: TKey) -> str | None: + if not self._enabled: + return None with self._lock: val = self._cache.get(key) if val is None: @@ -79,6 +82,8 @@ def get(self, key: TKey) -> str | None: return val[1] def set(self, key: TKey, value: str | None) -> None: + if not self._enabled: + return with self._lock: if value is None: self._cache.pop(key, None) @@ -86,6 +91,8 @@ def set(self, key: TKey, value: str | None) -> None: self._cache[key] = (monotonic(), value) def clear(self) -> None: + if not self._enabled: + return with self._lock: self._cache = {} self._last_clean = monotonic() diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index d973b2ac..8819e89c 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -157,10 +157,7 @@ def __init__( ), self.PROTOCOL_VERSION, ) - # so far `connection.recv_timeout_seconds` is the only available - # configuration hint that exists. Therefore, all hints can be stored at - # connection level. This might change in the future. - self.configuration_hints = {} + self.connection_hints = {} self.patch = {} self.outbox = AsyncOutbox( self.socket, @@ -209,6 +206,10 @@ def _get_client_state_manager(self) -> ClientStateManagerBase: ... def connection_id(self): return self.server_info._metadata.get("connection_id", "") + @property + @abc.abstractmethod + def ssr_enabled(self) -> bool: ... 
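With `max_size` now stored by `__init__`, the `_clean` pass inherited from patch 01 actually enforces the cache's size bound: it keeps the `max_size` entries with the newest timestamps. A standalone sketch of that mechanism with made-up entries:

.. code-block:: python

    import heapq
    from time import monotonic

    # Entries mirror the cache layout: key -> (timestamp, database).
    max_size = 2
    cache = {
        "old": (monotonic() - 30.0, "db-old"),
        "mid": (monotonic() - 20.0, "db-mid"),
        "new": (monotonic() - 10.0, "db-new"),
    }
    # Keep only the max_size most recently written entries.
    cache = dict(
        heapq.nlargest(max_size, cache.items(), key=lambda item: item[1][0])
    )
    assert set(cache) == {"mid", "new"}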
+ @property @abc.abstractmethod def supports_multiple_results(self): @@ -291,6 +292,7 @@ def protocol_handlers(cls, protocol_version=None): AsyncBolt5x5, AsyncBolt5x6, AsyncBolt5x7, + AsyncBolt5x8, ) handlers = { @@ -308,6 +310,7 @@ def protocol_handlers(cls, protocol_version=None): AsyncBolt5x5.PROTOCOL_VERSION: AsyncBolt5x5, AsyncBolt5x6.PROTOCOL_VERSION: AsyncBolt5x6, AsyncBolt5x7.PROTOCOL_VERSION: AsyncBolt5x7, + AsyncBolt5x8.PROTOCOL_VERSION: AsyncBolt5x8, } if protocol_version is None: @@ -444,7 +447,10 @@ async def open( # avoid new lines after imports for better readability and conciseness # fmt: off - if protocol_version == (5, 7): + if protocol_version == (5, 8): + from ._bolt5 import AsyncBolt5x8 + bolt_cls = AsyncBolt5x8 + elif protocol_version == (5, 7): from ._bolt5 import AsyncBolt5x7 bolt_cls = AsyncBolt5x7 elif protocol_version == (5, 6): diff --git a/src/neo4j/_async/io/_bolt3.py b/src/neo4j/_async/io/_bolt3.py index 08e75abb..2997296f 100644 --- a/src/neo4j/_async/io/_bolt3.py +++ b/src/neo4j/_async/io/_bolt3.py @@ -148,6 +148,8 @@ class AsyncBolt3(AsyncBolt): PROTOCOL_VERSION = Version(3, 0) + ssr_enabled = False + supports_multiple_results = False supports_multiple_databases = False diff --git a/src/neo4j/_async/io/_bolt4.py b/src/neo4j/_async/io/_bolt4.py index 202d5570..abc9d4cb 100644 --- a/src/neo4j/_async/io/_bolt4.py +++ b/src/neo4j/_async/io/_bolt4.py @@ -64,6 +64,8 @@ class AsyncBolt4x0(AsyncBolt): PROTOCOL_VERSION = Version(4, 0) + ssr_enabled = False + supports_multiple_results = True supports_multiple_databases = True @@ -614,10 +616,10 @@ async def hello(self, dehydration_hooks=None, hydration_hooks=None): ) def on_success(metadata): - self.configuration_hints.update(metadata.pop("hints", {})) + self.connection_hints.update(metadata.pop("hints", {})) self.server_info.update(metadata) - if "connection.recv_timeout_seconds" in self.configuration_hints: - recv_timeout = self.configuration_hints[ + if "connection.recv_timeout_seconds" in self.connection_hints: + recv_timeout = self.connection_hints[ "connection.recv_timeout_seconds" ] if isinstance(recv_timeout, int) and recv_timeout > 0: diff --git a/src/neo4j/_async/io/_bolt5.py b/src/neo4j/_async/io/_bolt5.py index 06336193..a6f9e469 100644 --- a/src/neo4j/_async/io/_bolt5.py +++ b/src/neo4j/_async/io/_bolt5.py @@ -107,6 +107,10 @@ def _on_client_state_change(self, old_state, new_state): def _get_client_state_manager(self) -> ClientStateManagerBase: return self._client_state_manager + @property + def ssr_enabled(self) -> bool: + return False + @property def is_reset(self): # We can't be sure of the server's state if there are still pending @@ -141,10 +145,10 @@ async def hello(self, dehydration_hooks=None, hydration_hooks=None): ) def on_success(metadata): - self.configuration_hints.update(metadata.pop("hints", {})) + self.connection_hints.update(metadata.pop("hints", {})) self.server_info.update(metadata) - if "connection.recv_timeout_seconds" in self.configuration_hints: - recv_timeout = self.configuration_hints[ + if "connection.recv_timeout_seconds" in self.connection_hints: + recv_timeout = self.connection_hints[ "connection.recv_timeout_seconds" ] if isinstance(recv_timeout, int) and recv_timeout > 0: @@ -615,10 +619,10 @@ async def hello(self, dehydration_hooks=None, hydration_hooks=None): ) def on_success(metadata): - self.configuration_hints.update(metadata.pop("hints", {})) + self.connection_hints.update(metadata.pop("hints", {})) self.server_info.update(metadata) - if 
"connection.recv_timeout_seconds" in self.configuration_hints: - recv_timeout = self.configuration_hints[ + if "connection.recv_timeout_seconds" in self.connection_hints: + recv_timeout = self.connection_hints[ "connection.recv_timeout_seconds" ] if isinstance(recv_timeout, int) and recv_timeout > 0: @@ -702,10 +706,10 @@ async def hello(self, dehydration_hooks=None, hydration_hooks=None): ) def on_success(metadata): - self.configuration_hints.update(metadata.pop("hints", {})) + self.connection_hints.update(metadata.pop("hints", {})) self.server_info.update(metadata) - if "connection.recv_timeout_seconds" in self.configuration_hints: - recv_timeout = self.configuration_hints[ + if "connection.recv_timeout_seconds" in self.connection_hints: + recv_timeout = self.connection_hints[ "connection.recv_timeout_seconds" ] if isinstance(recv_timeout, int) and recv_timeout > 0: @@ -883,7 +887,7 @@ def telemetry( hydration_hooks=None, **handlers, ) -> None: - if self.telemetry_disabled or not self.configuration_hints.get( + if self.telemetry_disabled or not self.connection_hints.get( "telemetry.enabled", False ): return @@ -1225,3 +1229,11 @@ async def _process_message(self, tag, fields): ) return len(details), 1 + + +class AsyncBolt5x8(AsyncBolt5x7): + PROTOCOL_VERSION = Version(5, 8) + + @property + def ssr_enabled(self) -> bool: + return self.connection_hints.get("ssr.enabled", False) is True diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index 38e5900b..fbcae249 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -80,6 +80,37 @@ class AcquireAuth: force_auth: bool = False +@dataclass +class ConnectionFeatureTracker: + feature_check: t.Callable[[AsyncBolt], bool] + with_feature: int = 0 + without_feature: int = 0 + + @property + def has_feature(self): + return self.with_feature > 0 and self.without_feature == 0 + + def add_connection(self, connection): + if self.feature_check(connection): + self.with_feature += 1 + else: + self.without_feature += 1 + + def remove_connection(self, connection): + if self.feature_check(connection): + if self.with_feature == 0: + raise ValueError( + "No connections to be removed from feature tracker" + ) + self.with_feature -= 1 + else: + if self.without_feature == 0: + raise ValueError( + "No connections to be removed from feature tracker" + ) + self.without_feature -= 1 + + class AsyncIOPool(abc.ABC): """A collection of connections to one or more server addresses.""" @@ -96,11 +127,18 @@ def __init__(self, opener, pool_config, workspace_config): self.lock = AsyncCooperativeRLock() self.cond = AsyncCondition(self.lock) self.home_db_cache = AsyncHomeDbCache(max_size=10_000) + self._ssr_feature_tracker = ConnectionFeatureTracker( + feature_check=lambda connection: connection.ssr_enabled + ) @property @abc.abstractmethod def is_direct_pool(self) -> bool: ... + @property + def ssr_enabled(self) -> bool: + return self._ssr_feature_tracker.has_feature + async def __aenter__(self): return self @@ -135,6 +173,20 @@ def _remove_connection(self, connection): # connection isn't in the pool anymore. 
with suppress(ValueError): self.connections.get(address, []).remove(connection) + self._ssr_feature_tracker.remove_connection(connection) + + def _add_connections(self, address, *connections): + with self.lock: + self.connections[address].extend(connections) + for connection in connections: + self._ssr_feature_tracker.add_connection(connection) + + def _remove_connections(self, address, *connections): + with self.lock: + existing_connections = self.connections.get(address, []) + for connection in connections: + existing_connections.remove(connection) + self._ssr_feature_tracker.remove_connection(connection) async def _acquire_from_pool_checked( self, address, health_check, deadline @@ -195,7 +247,7 @@ async def connection_creator(): with self.lock: self.connections_reservations[address] -= 1 released_reservation = True - self.connections[address].append(connection) + self._add_connections(address, connection) return connection finally: if not released_reservation: @@ -495,8 +547,7 @@ async def deactivate(self, address): # First remove all connections in question, then try to close them. # If closing of a connection fails, we will end up in this method # again. - for conn in closable_connections: - connections.remove(conn) + self._remove_connections(address, *closable_connections) if not self.connections[address]: del self.connections[address] @@ -542,6 +593,8 @@ async def close(self): for address in list(self.connections) for connection in self.connections.pop(address, ()) ] + for connection in connections: + self._ssr_feature_tracker.remove_connection(connection) await self._close_connections(connections) except TypeError: pass diff --git a/src/neo4j/_async/work/result.py b/src/neo4j/_async/work/result.py index 696358de..11d8b043 100644 --- a/src/neo4j/_async/work/result.py +++ b/src/neo4j/_async/work/result.py @@ -114,6 +114,7 @@ def __init__( warn_notification_severity, on_closed, on_error, + on_database, ) -> None: self._connection_cls = connection.__class__ self._connection = ConnectionErrorHandler( @@ -122,6 +123,7 @@ def __init__( self._hydration_scope = connection.new_hydration_scope() self._on_error = on_error self._on_closed = on_closed + self._on_database = on_database self._metadata: dict = {} self._address: Address = self._connection.unresolved_address self._keys: tuple[str, ...] 
= () @@ -197,7 +199,7 @@ async def _run( } self._database = db - def on_attached(metadata): + async def on_attached(metadata): self._metadata.update(metadata) # For auto-commit there is no qid and Bolt 3 does not support qid self._raw_qid = metadata.get("qid", -1) @@ -205,6 +207,9 @@ def on_attached(metadata): self._connection.most_recent_qid = self._raw_qid self._keys = metadata.get("fields") self._attached = True + db_ = metadata.get("db") + if isinstance(db_, str): + await AsyncUtil.callback(self._on_database, db_) async def on_failed_attach(metadata): self._metadata.update(metadata) diff --git a/src/neo4j/_async/work/session.py b/src/neo4j/_async/work/session.py index 77c0a7bf..264b31f3 100644 --- a/src/neo4j/_async/work/session.py +++ b/src/neo4j/_async/work/session.py @@ -321,6 +321,7 @@ async def run( self._config.warn_notification_severity, self._result_closed, self._result_error, + self._make_query_database_resolution_callback(), ) bookmarks = await self._get_bookmarks() parameters = dict(parameters or {}, **kwargs) @@ -448,6 +449,7 @@ async def _open_transaction( self._transaction_closed_handler, self._transaction_error_handler, self._transaction_cancel_handler, + self._make_query_database_resolution_callback(), ) bookmarks = await self._get_bookmarks() await self._transaction._begin( diff --git a/src/neo4j/_async/work/transaction.py b/src/neo4j/_async/work/transaction.py index 2a1aa062..4ce9937f 100644 --- a/src/neo4j/_async/work/transaction.py +++ b/src/neo4j/_async/work/transaction.py @@ -47,6 +47,7 @@ def __init__( on_closed, on_error, on_cancel, + on_database, ): self._connection = connection self._error_handling_connection = ConnectionErrorHandler( @@ -62,6 +63,7 @@ def __init__( self._on_closed = on_closed self._on_error = on_error self._on_cancel = on_cancel + self._on_database = on_database super().__init__() async def _enter(self) -> te.Self: @@ -92,6 +94,11 @@ async def _begin( notifications_disabled_classifications, pipelined=False, ): + async def on_begin_success(metadata_): + db = metadata_.get("db") + if isinstance(db, str): + await AsyncUtil.callback(self._on_database, db) + self._database = database self._connection.begin( bookmarks=bookmarks, @@ -102,6 +109,7 @@ async def _begin( imp_user=imp_user, notifications_min_severity=notifications_min_severity, notifications_disabled_classifications=notifications_disabled_classifications, + on_success=on_begin_success, ) if not pipelined: await self._error_handling_connection.send_all() @@ -188,6 +196,7 @@ async def run( self._warn_notification_severity, self._result_on_closed_handler, self._error_handler, + None, ) self._results.append(result) diff --git a/src/neo4j/_async/work/workspace.py b/src/neo4j/_async/work/workspace.py index 7424a3cc..22ecf1ed 100644 --- a/src/neo4j/_async/work/workspace.py +++ b/src/neo4j/_async/work/workspace.py @@ -18,6 +18,7 @@ import logging import typing as t +from dataclasses import dataclass from ..._async_compat.util import AsyncUtil from ..._auth_management import to_auth_dict @@ -56,6 +57,12 @@ log = logging.getLogger("neo4j") +@dataclass +class _TargetDatabase: + database: str | None + from_cache: bool = False + + class AsyncWorkspace(AsyncNonConcurrentMethodChecker): def __init__(self, pool, config): assert isinstance(config, WorkspaceConfig) @@ -63,8 +70,9 @@ def __init__(self, pool, config): self._config = config self._connection = None self._connection_access_mode = None + self._last_cache_key: TKey | None = None # Sessions are supposed to cache the database on which to operate. 
- self._cached_database = False + self._pinned_database = False self._bookmarks = () self._initial_bookmarks = () self._bookmark_manager = None @@ -103,20 +111,33 @@ async def __aenter__(self) -> AsyncWorkspace: async def __aexit__(self, exc_type, exc_value, traceback): await self.close() - def _make_database_callback( + def _make_routing_database_callback( self, cache_key: TKey, ) -> t.Callable[[str], None]: def _database_callback(database: str | None) -> None: + if not self._pinned_database: + self._set_pinned_database(database) db_cache: AsyncHomeDbCache = self._pool.home_db_cache - if db_cache.enabled: - db_cache.set(cache_key, database) - self._set_cached_database(database) + db_cache.set(cache_key, database) return _database_callback - def _set_cached_database(self, database): - self._cached_database = True + def _make_query_database_resolution_callback( + self, + ) -> t.Callable[[str], None] | None: + def _database_callback(database: str | None) -> None: + if not self._pinned_database: + self._set_pinned_database(database) + if self._last_cache_key is None: + return + db_cache: AsyncHomeDbCache = self._pool.home_db_cache + db_cache.set(self._last_cache_key, database) + + return _database_callback + + def _set_pinned_database(self, database): + self._pinned_database = True self._config.database = database def _initialize_bookmarks(self, bookmarks): @@ -166,7 +187,7 @@ async def _update_bookmark(self, bookmark): return await self._update_bookmarks((bookmark,)) - async def _connect(self, access_mode, auth=None, **acquire_kwargs): + async def _connect(self, access_mode, auth=None, **acquire_kwargs) -> None: acquisition_timeout = self._config.connection_acquisition_timeout force_auth = acquire_kwargs.pop("force_auth", False) acquire_auth = AcquireAuth(auth, force_auth=force_auth) @@ -177,63 +198,82 @@ async def _connect(self, access_mode, auth=None, **acquire_kwargs): await self._connection.send_all() await self._connection.fetch_all() await self._disconnect() - await self._fill_cached_database(acquire_auth) + + ssr_enabled = self._pool.ssr_enabled + routing_target = await self._get_routing_target_database( + acquire_auth, ssr_enabled=ssr_enabled + ) acquire_kwargs_ = { "access_mode": access_mode, "timeout": acquisition_timeout, - "database": self._config.database, + "database": routing_target.database, "bookmarks": await self._get_bookmarks(), "auth": acquire_auth, "liveness_check_timeout": None, } acquire_kwargs_.update(acquire_kwargs) self._connection = await self._pool.acquire(**acquire_kwargs_) + if routing_target.from_cache and ( + not self._pool.ssr_enabled or not self._connection.ssr_enabled + ): + # race condition: in the meantime, the pool added a connection, + # which does not support SSR. 
+ # => we need to fall back to explicit home database resolution + await self._disconnect() + routing_target = await self._get_routing_target_database( + acquire_auth, ssr_enabled=False + ) + acquire_kwargs_["database"] = routing_target.database + self._connection = await self._pool.acquire(**acquire_kwargs_) self._connection_access_mode = access_mode - async def _fill_cached_database(self, acquire_auth: AcquireAuth) -> None: + async def _get_routing_target_database( + self, + acquire_auth: AcquireAuth, + ssr_enabled: bool, + ) -> _TargetDatabase: + if self._config.database is not None or not isinstance( + self._pool, AsyncNeo4jPool + ): + self._set_pinned_database(self._config.database) + log.debug( + "[#0000] _: <WORKSPACE> routing towards fixed database: %s", + self._config.database, + ) + return _TargetDatabase(self._config.database) + auth = acquire_auth.auth - acquisition_timeout = self._config.connection_acquisition_timeout - if not self._cached_database: - if self._config.database is not None or not isinstance( - self._pool, AsyncNeo4jPool - ): - self._set_cached_database(self._config.database) - else: - # This is the first time we open a connection to a server in a - # cluster environment for this session without explicitly - # configured database. Hence, we request a routing table update - # to try to fetch the home database. If provided by the server, - # we shall use this database explicitly for all subsequent - # actions within this session. - # Unless we have the resolved home db in our cache: - - db_cache: AsyncHomeDbCache = self._pool.home_db_cache - cache_key = cached_db = None - if db_cache.enabled: - cache_key = db_cache.compute_key( - self._config.impersonated_user, - await self._resolve_session_auth(auth), - ) - cached_db = db_cache.get(cache_key) - if cached_db is not None: - log.debug( - ( - "[#0000] _: <WORKSPACE> resolved home database " - "from cache: %s" - ), - cached_db, - ) - self._set_cached_database(cached_db) - return - log.debug("[#0000] _: <WORKSPACE> resolve home database") - await self._pool.update_routing_table( - database=self._config.database, - imp_user=self._config.impersonated_user, - bookmarks=await self._get_bookmarks(), - auth=acquire_auth, - acquisition_timeout=acquisition_timeout, - database_callback=self._make_database_callback(cache_key), + resolved_auth = await self._resolve_session_auth(auth) + db_cache: AsyncHomeDbCache = self._pool.home_db_cache + cache_key = db_cache.compute_key( + self._config.impersonated_user, + resolved_auth, + ) + self._last_cache_key = cache_key + + if ssr_enabled: + cached_db = db_cache.get(cache_key) + if cached_db is not None: + log.debug( + ( + "[#0000] _: <WORKSPACE> routing towards cached " + "database: %s" + ), + cached_db, ) + return _TargetDatabase(cached_db, from_cache=True) + + acquisition_timeout = self._config.connection_acquisition_timeout + log.debug("[#0000] _: <WORKSPACE> resolve home database") + await self._pool.update_routing_table( + database=self._config.database, + imp_user=self._config.impersonated_user, + bookmarks=await self._get_bookmarks(), + auth=acquire_auth, + acquisition_timeout=acquisition_timeout, + database_callback=self._make_routing_database_callback(cache_key), + ) + return _TargetDatabase(self._config.database) @staticmethod async def _resolve_session_auth( @@ -254,6 +294,7 @@ def _resolve_session_auth( return to_auth_dict(resolved_auth) async def _disconnect(self, sync=False): + self._last_cache_key = None if self._connection: if sync: try: diff --git a/src/neo4j/_sync/home_db_cache.py b/src/neo4j/_sync/home_db_cache.py index 
8f8dd21f..e56d052f 100644 --- a/src/neo4j/_sync/home_db_cache.py +++ b/src/neo4j/_sync/home_db_cache.py @@ -28,7 +28,7 @@ if t.TYPE_CHECKING: # TAuthKey = t.Tuple[t.Tuple[]] - TKey = str | tuple[tuple[str, t.Hashable], ...] | None + TKey = str | tuple[tuple[str, t.Hashable], ...] | tuple[None] TVal = tuple[float, str] @@ -39,17 +39,18 @@ class HomeDbCache: def __init__( self, - ttl: float = float("inf"), enabled: bool = True, + ttl: float = float("inf"), max_size: int | None = None, ) -> None: if math.isnan(ttl) or ttl <= 0: - raise ValueError("home db cache ttl must be greater than 0") + raise ValueError(f"home db cache ttl must be greater than 0, got {ttl}") self._enabled = enabled self._ttl = ttl self._cache: dict[TKey, TVal] = {} self._lock = CooperativeLock() self._last_clean = monotonic() + self._max_size = max_size def compute_key( self, @@ -57,14 +58,16 @@ def compute_key( auth: dict | None, ) -> TKey: if not self._enabled: - return None + return (None,) if imp_user is not None: return imp_user if auth is not None: return _hashable_dict(auth) - return None + return (None,) def get(self, key: TKey) -> str | None: + if not self._enabled: + return None with self._lock: val = self._cache.get(key) if val is None: @@ -79,6 +82,8 @@ def get(self, key: TKey) -> str | None: return val[1] def set(self, key: TKey, value: str | None) -> None: + if not self._enabled: + return with self._lock: if value is None: self._cache.pop(key, None) @@ -86,6 +91,8 @@ def set(self, key: TKey, value: str | None) -> None: self._cache[key] = (monotonic(), value) def clear(self) -> None: + if not self._enabled: + return with self._lock: self._cache = {} self._last_clean = monotonic() diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index 137aa428..223f6f85 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -157,10 +157,7 @@ def __init__( ), self.PROTOCOL_VERSION, ) - # so far `connection.recv_timeout_seconds` is the only available - # configuration hint that exists. Therefore, all hints can be stored at - # connection level. This might change in the future. - self.configuration_hints = {} + self.connection_hints = {} self.patch = {} self.outbox = Outbox( self.socket, @@ -209,6 +206,10 @@ def _get_client_state_manager(self) -> ClientStateManagerBase: ... def connection_id(self): return self.server_info._metadata.get("connection_id", "") + @property + @abc.abstractmethod + def ssr_enabled(self) -> bool: ... 
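The abstract `ssr_enabled` property declared here is implemented by `Bolt5x8` further below as a strict check of the `ssr.enabled` connection hint, so anything but a boolean `True` counts as unsupported. A standalone sketch of that check with made-up hint payloads:

.. code-block:: python

    def ssr_enabled(connection_hints: dict) -> bool:
        # Mirrors Bolt5x8: only a literal boolean True enables SSR.
        return connection_hints.get("ssr.enabled", False) is True

    assert ssr_enabled({"ssr.enabled": True})
    assert not ssr_enabled({"ssr.enabled": "true"})  # wrong type: rejected
    assert not ssr_enabled({})                       # hint absent: rejected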
+ @property @abc.abstractmethod def supports_multiple_results(self): @@ -291,6 +292,7 @@ def protocol_handlers(cls, protocol_version=None): Bolt5x5, Bolt5x6, Bolt5x7, + Bolt5x8, ) handlers = { @@ -308,6 +310,7 @@ def protocol_handlers(cls, protocol_version=None): Bolt5x5.PROTOCOL_VERSION: Bolt5x5, Bolt5x6.PROTOCOL_VERSION: Bolt5x6, Bolt5x7.PROTOCOL_VERSION: Bolt5x7, + Bolt5x8.PROTOCOL_VERSION: Bolt5x8, } if protocol_version is None: @@ -444,7 +447,10 @@ def open( # avoid new lines after imports for better readability and conciseness # fmt: off - if protocol_version == (5, 7): + if protocol_version == (5, 8): + from ._bolt5 import Bolt5x8 + bolt_cls = Bolt5x8 + elif protocol_version == (5, 7): from ._bolt5 import Bolt5x7 bolt_cls = Bolt5x7 elif protocol_version == (5, 6): diff --git a/src/neo4j/_sync/io/_bolt3.py b/src/neo4j/_sync/io/_bolt3.py index e3cfd142..3f4c93a3 100644 --- a/src/neo4j/_sync/io/_bolt3.py +++ b/src/neo4j/_sync/io/_bolt3.py @@ -148,6 +148,8 @@ class Bolt3(Bolt): PROTOCOL_VERSION = Version(3, 0) + ssr_enabled = False + supports_multiple_results = False supports_multiple_databases = False diff --git a/src/neo4j/_sync/io/_bolt4.py b/src/neo4j/_sync/io/_bolt4.py index 69bb6dd6..99c04185 100644 --- a/src/neo4j/_sync/io/_bolt4.py +++ b/src/neo4j/_sync/io/_bolt4.py @@ -64,6 +64,8 @@ class Bolt4x0(Bolt): PROTOCOL_VERSION = Version(4, 0) + ssr_enabled = False + supports_multiple_results = True supports_multiple_databases = True @@ -614,10 +616,10 @@ def hello(self, dehydration_hooks=None, hydration_hooks=None): ) def on_success(metadata): - self.configuration_hints.update(metadata.pop("hints", {})) + self.connection_hints.update(metadata.pop("hints", {})) self.server_info.update(metadata) - if "connection.recv_timeout_seconds" in self.configuration_hints: - recv_timeout = self.configuration_hints[ + if "connection.recv_timeout_seconds" in self.connection_hints: + recv_timeout = self.connection_hints[ "connection.recv_timeout_seconds" ] if isinstance(recv_timeout, int) and recv_timeout > 0: diff --git a/src/neo4j/_sync/io/_bolt5.py b/src/neo4j/_sync/io/_bolt5.py index 4138a9d5..27e0b695 100644 --- a/src/neo4j/_sync/io/_bolt5.py +++ b/src/neo4j/_sync/io/_bolt5.py @@ -107,6 +107,10 @@ def _on_client_state_change(self, old_state, new_state): def _get_client_state_manager(self) -> ClientStateManagerBase: return self._client_state_manager + @property + def ssr_enabled(self) -> bool: + return False + @property def is_reset(self): # We can't be sure of the server's state if there are still pending @@ -141,10 +145,10 @@ def hello(self, dehydration_hooks=None, hydration_hooks=None): ) def on_success(metadata): - self.configuration_hints.update(metadata.pop("hints", {})) + self.connection_hints.update(metadata.pop("hints", {})) self.server_info.update(metadata) - if "connection.recv_timeout_seconds" in self.configuration_hints: - recv_timeout = self.configuration_hints[ + if "connection.recv_timeout_seconds" in self.connection_hints: + recv_timeout = self.connection_hints[ "connection.recv_timeout_seconds" ] if isinstance(recv_timeout, int) and recv_timeout > 0: @@ -615,10 +619,10 @@ def hello(self, dehydration_hooks=None, hydration_hooks=None): ) def on_success(metadata): - self.configuration_hints.update(metadata.pop("hints", {})) + self.connection_hints.update(metadata.pop("hints", {})) self.server_info.update(metadata) - if "connection.recv_timeout_seconds" in self.configuration_hints: - recv_timeout = self.configuration_hints[ + if "connection.recv_timeout_seconds" in 
self.connection_hints: + recv_timeout = self.connection_hints[ "connection.recv_timeout_seconds" ] if isinstance(recv_timeout, int) and recv_timeout > 0: @@ -702,10 +706,10 @@ def hello(self, dehydration_hooks=None, hydration_hooks=None): ) def on_success(metadata): - self.configuration_hints.update(metadata.pop("hints", {})) + self.connection_hints.update(metadata.pop("hints", {})) self.server_info.update(metadata) - if "connection.recv_timeout_seconds" in self.configuration_hints: - recv_timeout = self.configuration_hints[ + if "connection.recv_timeout_seconds" in self.connection_hints: + recv_timeout = self.connection_hints[ "connection.recv_timeout_seconds" ] if isinstance(recv_timeout, int) and recv_timeout > 0: @@ -883,7 +887,7 @@ def telemetry( hydration_hooks=None, **handlers, ) -> None: - if self.telemetry_disabled or not self.configuration_hints.get( + if self.telemetry_disabled or not self.connection_hints.get( "telemetry.enabled", False ): return @@ -1225,3 +1229,11 @@ def _process_message(self, tag, fields): ) return len(details), 1 + + +class Bolt5x8(Bolt5x7): + PROTOCOL_VERSION = Version(5, 8) + + @property + def ssr_enabled(self) -> bool: + return self.connection_hints.get("ssr.enabled", False) is True diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py index 8ea85586..e5b406cf 100644 --- a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -77,6 +77,37 @@ class AcquireAuth: force_auth: bool = False +@dataclass +class ConnectionFeatureTracker: + feature_check: t.Callable[[Bolt], bool] + with_feature: int = 0 + without_feature: int = 0 + + @property + def has_feature(self): + return self.with_feature > 0 and self.without_feature == 0 + + def add_connection(self, connection): + if self.feature_check(connection): + self.with_feature += 1 + else: + self.without_feature += 1 + + def remove_connection(self, connection): + if self.feature_check(connection): + if self.with_feature == 0: + raise ValueError( + "No connections to be removed from feature tracker" + ) + self.with_feature -= 1 + else: + if self.without_feature == 0: + raise ValueError( + "No connections to be removed from feature tracker" + ) + self.without_feature -= 1 + + class IOPool(abc.ABC): """A collection of connections to one or more server addresses.""" @@ -93,11 +124,18 @@ def __init__(self, opener, pool_config, workspace_config): self.lock = CooperativeRLock() self.cond = Condition(self.lock) self.home_db_cache = HomeDbCache(max_size=10_000) + self._ssr_feature_tracker = ConnectionFeatureTracker( + feature_check=lambda connection: connection.ssr_enabled + ) @property @abc.abstractmethod def is_direct_pool(self) -> bool: ... + @property + def ssr_enabled(self) -> bool: + return self._ssr_feature_tracker.has_feature + def __enter__(self): return self @@ -132,6 +170,20 @@ def _remove_connection(self, connection): # connection isn't in the pool anymore. 
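            # (list.remove() raises ValueError in that case, hence the
            # suppress below)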
with suppress(ValueError): self.connections.get(address, []).remove(connection) + self._ssr_feature_tracker.remove_connection(connection) + + def _add_connections(self, address, *connections): + with self.lock: + self.connections[address].extend(connections) + for connection in connections: + self._ssr_feature_tracker.add_connection(connection) + + def _remove_connections(self, address, *connections): + with self.lock: + existing_connections = self.connections.get(address, []) + for connection in connections: + existing_connections.remove(connection) + self._ssr_feature_tracker.remove_connection(connection) def _acquire_from_pool_checked( self, address, health_check, deadline @@ -192,7 +244,7 @@ def connection_creator(): with self.lock: self.connections_reservations[address] -= 1 released_reservation = True - self.connections[address].append(connection) + self._add_connections(address, connection) return connection finally: if not released_reservation: @@ -492,8 +544,7 @@ def deactivate(self, address): # First remove all connections in question, then try to close them. # If closing of a connection fails, we will end up in this method # again. - for conn in closable_connections: - connections.remove(conn) + self._remove_connections(address, *closable_connections) if not self.connections[address]: del self.connections[address] @@ -539,6 +590,8 @@ def close(self): for address in list(self.connections) for connection in self.connections.pop(address, ()) ] + for connection in connections: + self._ssr_feature_tracker.remove_connection(connection) self._close_connections(connections) except TypeError: pass diff --git a/src/neo4j/_sync/work/result.py b/src/neo4j/_sync/work/result.py index f343efa2..d2e1060e 100644 --- a/src/neo4j/_sync/work/result.py +++ b/src/neo4j/_sync/work/result.py @@ -114,6 +114,7 @@ def __init__( warn_notification_severity, on_closed, on_error, + on_database, ) -> None: self._connection_cls = connection.__class__ self._connection = ConnectionErrorHandler( @@ -122,6 +123,7 @@ def __init__( self._hydration_scope = connection.new_hydration_scope() self._on_error = on_error self._on_closed = on_closed + self._on_database = on_database self._metadata: dict = {} self._address: Address = self._connection.unresolved_address self._keys: tuple[str, ...] 
= () @@ -205,6 +207,9 @@ def on_attached(metadata): self._connection.most_recent_qid = self._raw_qid self._keys = metadata.get("fields") self._attached = True + db_ = metadata.get("db") + if isinstance(db_, str): + Util.callback(self._on_database, db_) def on_failed_attach(metadata): self._metadata.update(metadata) diff --git a/src/neo4j/_sync/work/session.py b/src/neo4j/_sync/work/session.py index 61bd23b8..99f77a4f 100644 --- a/src/neo4j/_sync/work/session.py +++ b/src/neo4j/_sync/work/session.py @@ -321,6 +321,7 @@ def run( self._config.warn_notification_severity, self._result_closed, self._result_error, + self._make_query_database_resolution_callback(), ) bookmarks = self._get_bookmarks() parameters = dict(parameters or {}, **kwargs) @@ -448,6 +449,7 @@ def _open_transaction( self._transaction_closed_handler, self._transaction_error_handler, self._transaction_cancel_handler, + self._make_query_database_resolution_callback(), ) bookmarks = self._get_bookmarks() self._transaction._begin( diff --git a/src/neo4j/_sync/work/transaction.py b/src/neo4j/_sync/work/transaction.py index f1625a24..c0a270f9 100644 --- a/src/neo4j/_sync/work/transaction.py +++ b/src/neo4j/_sync/work/transaction.py @@ -47,6 +47,7 @@ def __init__( on_closed, on_error, on_cancel, + on_database, ): self._connection = connection self._error_handling_connection = ConnectionErrorHandler( @@ -62,6 +63,7 @@ def __init__( self._on_closed = on_closed self._on_error = on_error self._on_cancel = on_cancel + self._on_database = on_database super().__init__() def _enter(self) -> te.Self: @@ -92,6 +94,11 @@ def _begin( notifications_disabled_classifications, pipelined=False, ): + def on_begin_success(metadata_): + db = metadata_.get("db") + if isinstance(db, str): + Util.callback(self._on_database, db) + self._database = database self._connection.begin( bookmarks=bookmarks, @@ -102,6 +109,7 @@ def _begin( imp_user=imp_user, notifications_min_severity=notifications_min_severity, notifications_disabled_classifications=notifications_disabled_classifications, + on_success=on_begin_success, ) if not pipelined: self._error_handling_connection.send_all() @@ -188,6 +196,7 @@ def run( self._warn_notification_severity, self._result_on_closed_handler, self._error_handler, + None, ) self._results.append(result) diff --git a/src/neo4j/_sync/work/workspace.py b/src/neo4j/_sync/work/workspace.py index 7a30cb65..e13ba382 100644 --- a/src/neo4j/_sync/work/workspace.py +++ b/src/neo4j/_sync/work/workspace.py @@ -18,6 +18,7 @@ import logging import typing as t +from dataclasses import dataclass from ..._async_compat.util import Util from ..._auth_management import to_auth_dict @@ -53,6 +54,12 @@ log = logging.getLogger("neo4j") +@dataclass +class _TargetDatabase: + database: str | None + from_cache: bool = False + + class Workspace(NonConcurrentMethodChecker): def __init__(self, pool, config): assert isinstance(config, WorkspaceConfig) @@ -60,8 +67,9 @@ def __init__(self, pool, config): self._config = config self._connection = None self._connection_access_mode = None + self._last_cache_key: TKey | None = None # Sessions are supposed to cache the database on which to operate. 
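+        # A database is "pinned" once it has been resolved, whether from
+        # explicit configuration, the home-db cache, or a routing table
+        # fetch; it then stays fixed for this workspace's lifetime.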
- self._cached_database = False + self._pinned_database = False self._bookmarks = () self._initial_bookmarks = () self._bookmark_manager = None @@ -100,20 +108,33 @@ def __enter__(self) -> Workspace: def __exit__(self, exc_type, exc_value, traceback): self.close() - def _make_database_callback( + def _make_routing_database_callback( self, cache_key: TKey, ) -> t.Callable[[str], None]: def _database_callback(database: str | None) -> None: + if not self._pinned_database: + self._set_pinned_database(database) db_cache: HomeDbCache = self._pool.home_db_cache - if db_cache.enabled: - db_cache.set(cache_key, database) - self._set_cached_database(database) + db_cache.set(cache_key, database) return _database_callback - def _set_cached_database(self, database): - self._cached_database = True + def _make_query_database_resolution_callback( + self, + ) -> t.Callable[[str], None] | None: + def _database_callback(database: str | None) -> None: + if not self._pinned_database: + self._set_pinned_database(database) + if self._last_cache_key is None: + return + db_cache: HomeDbCache = self._pool.home_db_cache + db_cache.set(self._last_cache_key, database) + + return _database_callback + + def _set_pinned_database(self, database): + self._pinned_database = True self._config.database = database def _initialize_bookmarks(self, bookmarks): @@ -163,7 +184,7 @@ def _update_bookmark(self, bookmark): return self._update_bookmarks((bookmark,)) - def _connect(self, access_mode, auth=None, **acquire_kwargs): + def _connect(self, access_mode, auth=None, **acquire_kwargs) -> None: acquisition_timeout = self._config.connection_acquisition_timeout force_auth = acquire_kwargs.pop("force_auth", False) acquire_auth = AcquireAuth(auth, force_auth=force_auth) @@ -174,63 +195,82 @@ def _connect(self, access_mode, auth=None, **acquire_kwargs): self._connection.send_all() self._connection.fetch_all() self._disconnect() - self._fill_cached_database(acquire_auth) + + ssr_enabled = self._pool.ssr_enabled + routing_target = self._get_routing_target_database( + acquire_auth, ssr_enabled=ssr_enabled + ) acquire_kwargs_ = { "access_mode": access_mode, "timeout": acquisition_timeout, - "database": self._config.database, + "database": routing_target.database, "bookmarks": self._get_bookmarks(), "auth": acquire_auth, "liveness_check_timeout": None, } acquire_kwargs_.update(acquire_kwargs) self._connection = self._pool.acquire(**acquire_kwargs_) + if routing_target.from_cache and ( + not self._pool.ssr_enabled or not self._connection.ssr_enabled + ): + # race condition: in the meantime, the pool added a connection, + # which does not support SSR. 
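+            # (self._pool.ssr_enabled aggregates over all pooled
+            # connections and can flip as connections come and go)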
+ # => we need to fall back to explicit home database resolution + self._disconnect() + routing_target = self._get_routing_target_database( + acquire_auth, ssr_enabled=False + ) + acquire_kwargs_["database"] = routing_target.database + self._connection = self._pool.acquire(**acquire_kwargs_) self._connection_access_mode = access_mode - def _fill_cached_database(self, acquire_auth: AcquireAuth) -> None: + def _get_routing_target_database( + self, + acquire_auth: AcquireAuth, + ssr_enabled: bool, + ) -> _TargetDatabase: + if self._config.database is not None or not isinstance( + self._pool, Neo4jPool + ): + self._set_pinned_database(self._config.database) + log.debug( + "[#0000] _: routing towards fixed database: %s", + self._config.database, + ) + return _TargetDatabase(self._config.database) + auth = acquire_auth.auth - acquisition_timeout = self._config.connection_acquisition_timeout - if not self._cached_database: - if self._config.database is not None or not isinstance( - self._pool, Neo4jPool - ): - self._set_cached_database(self._config.database) - else: - # This is the first time we open a connection to a server in a - # cluster environment for this session without explicitly - # configured database. Hence, we request a routing table update - # to try to fetch the home database. If provided by the server, - # we shall use this database explicitly for all subsequent - # actions within this session. - # Unless we have the resolved home db in out cache: - - db_cache: HomeDbCache = self._pool.home_db_cache - cache_key = cached_db = None - if db_cache.enabled: - cache_key = db_cache.compute_key( - self._config.impersonated_user, - self._resolve_session_auth(auth), - ) - cached_db = db_cache.get(cache_key) - if cached_db is not None: - log.debug( - ( - "[#0000] _: resolved home database " - "from cache: %s" - ), - cached_db, - ) - self._set_cached_database(cached_db) - return - log.debug("[#0000] _: resolve home database") - self._pool.update_routing_table( - database=self._config.database, - imp_user=self._config.impersonated_user, - bookmarks=self._get_bookmarks(), - auth=acquire_auth, - acquisition_timeout=acquisition_timeout, - database_callback=self._make_database_callback(cache_key), + resolved_auth = self._resolve_session_auth(auth) + db_cache: HomeDbCache = self._pool.home_db_cache + cache_key = db_cache.compute_key( + self._config.impersonated_user, + resolved_auth, + ) + self._last_cache_key = cache_key + + if ssr_enabled: + cached_db = db_cache.get(cache_key) + if cached_db is not None: + log.debug( + ( + "[#0000] _: routing towards cached " + "database: %s" + ), + cached_db, ) + return _TargetDatabase(cached_db, from_cache=True) + + acquisition_timeout = self._config.connection_acquisition_timeout + log.debug("[#0000] _: resolve home database") + self._pool.update_routing_table( + database=self._config.database, + imp_user=self._config.impersonated_user, + bookmarks=self._get_bookmarks(), + auth=acquire_auth, + acquisition_timeout=acquisition_timeout, + database_callback=self._make_routing_database_callback(cache_key), + ) + return _TargetDatabase(self._config.database) @staticmethod def _resolve_session_auth( @@ -251,6 +291,7 @@ def _resolve_session_auth( return to_auth_dict(resolved_auth) def _disconnect(self, sync=False): + self._last_cache_key = None if self._connection: if sync: try: diff --git a/tests/unit/async_/io/test_class_bolt.py b/tests/unit/async_/io/test_class_bolt.py index b0ddbc96..4469b045 100644 --- a/tests/unit/async_/io/test_class_bolt.py +++ 
b/tests/unit/async_/io/test_class_bolt.py @@ -38,7 +38,7 @@ def test_class_method_protocol_handlers(): expected_handlers = { (3, 0), (4, 1), (4, 2), (4, 3), (4, 4), - (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 7), + (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 7), (5, 8), } # fmt: on @@ -69,7 +69,8 @@ def test_class_method_protocol_handlers(): ((5, 5), 1), ((5, 6), 1), ((5, 7), 1), - ((5, 8), 0), + ((5, 8), 1), + ((5, 9), 0), ((6, 0), 0), ], ) @@ -92,7 +93,7 @@ def test_class_method_get_handshake(): handshake = AsyncBolt.get_handshake() assert ( handshake - == b"\x00\x07\x07\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" + == b"\x00\x08\x08\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" ) @@ -143,6 +144,7 @@ async def test_cancel_hello_in_open(mocker, none_auth): ((5, 5), "neo4j._async.io._bolt5.AsyncBolt5x5"), ((5, 6), "neo4j._async.io._bolt5.AsyncBolt5x6"), ((5, 7), "neo4j._async.io._bolt5.AsyncBolt5x7"), + ((5, 8), "neo4j._async.io._bolt5.AsyncBolt5x8"), ), ) @mark_async_test @@ -181,7 +183,7 @@ async def test_version_negotiation( (2, 0), (4, 0), (3, 1), - (5, 8), + (5, 9), (6, 0), ), ) diff --git a/tests/unit/async_/io/test_class_bolt3.py b/tests/unit/async_/io/test_class_bolt3.py index e2f56ff9..6442e8f9 100644 --- a/tests/unit/async_/io/test_class_bolt3.py +++ b/tests/unit/async_/io/test_class_bolt3.py @@ -129,7 +129,7 @@ async def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) await connection.send_all() diff --git a/tests/unit/async_/io/test_class_bolt4x0.py b/tests/unit/async_/io/test_class_bolt4x0.py index fa555fd1..771f7e30 100644 --- a/tests/unit/async_/io/test_class_bolt4x0.py +++ b/tests/unit/async_/io/test_class_bolt4x0.py @@ -232,7 +232,7 @@ async def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) await connection.send_all() diff --git a/tests/unit/async_/io/test_class_bolt4x1.py b/tests/unit/async_/io/test_class_bolt4x1.py index e7ca17e0..cd37fce4 100644 --- a/tests/unit/async_/io/test_class_bolt4x1.py +++ b/tests/unit/async_/io/test_class_bolt4x1.py @@ -254,7 +254,7 @@ async def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) await connection.send_all() diff --git a/tests/unit/async_/io/test_class_bolt4x2.py b/tests/unit/async_/io/test_class_bolt4x2.py index bffb4424..54180fa0 100644 --- a/tests/unit/async_/io/test_class_bolt4x2.py +++ b/tests/unit/async_/io/test_class_bolt4x2.py @@ -254,7 +254,7 @@ async def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) await connection.send_all() diff --git a/tests/unit/async_/io/test_class_bolt4x3.py b/tests/unit/async_/io/test_class_bolt4x3.py index 1f249feb..ba39bfbb 100644 --- a/tests/unit/async_/io/test_class_bolt4x3.py +++ b/tests/unit/async_/io/test_class_bolt4x3.py @@ -254,7 +254,7 @@ async def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - 
connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) await connection.send_all() diff --git a/tests/unit/async_/io/test_class_bolt4x4.py b/tests/unit/async_/io/test_class_bolt4x4.py index 695ac7c9..b66f8b32 100644 --- a/tests/unit/async_/io/test_class_bolt4x4.py +++ b/tests/unit/async_/io/test_class_bolt4x4.py @@ -278,7 +278,7 @@ async def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) await connection.send_all() diff --git a/tests/unit/async_/io/test_class_bolt5x0.py b/tests/unit/async_/io/test_class_bolt5x0.py index d1f09dcc..823d9ddb 100644 --- a/tests/unit/async_/io/test_class_bolt5x0.py +++ b/tests/unit/async_/io/test_class_bolt5x0.py @@ -278,7 +278,7 @@ async def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) await connection.send_all() diff --git a/tests/unit/async_/io/test_class_bolt5x1.py b/tests/unit/async_/io/test_class_bolt5x1.py index 003263aa..0eef369e 100644 --- a/tests/unit/async_/io/test_class_bolt5x1.py +++ b/tests/unit/async_/io/test_class_bolt5x1.py @@ -280,7 +280,7 @@ async def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) await connection.send_all() diff --git a/tests/unit/async_/io/test_class_bolt5x2.py b/tests/unit/async_/io/test_class_bolt5x2.py index 345c9a52..2e25ec1d 100644 --- a/tests/unit/async_/io/test_class_bolt5x2.py +++ b/tests/unit/async_/io/test_class_bolt5x2.py @@ -278,7 +278,7 @@ async def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) await connection.send_all() diff --git a/tests/unit/async_/io/test_class_bolt5x3.py b/tests/unit/async_/io/test_class_bolt5x3.py index c70a3df4..5d451025 100644 --- a/tests/unit/async_/io/test_class_bolt5x3.py +++ b/tests/unit/async_/io/test_class_bolt5x3.py @@ -281,7 +281,7 @@ async def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) await connection.send_all() diff --git a/tests/unit/async_/io/test_class_bolt5x4.py b/tests/unit/async_/io/test_class_bolt5x4.py index 7ff21e09..99d9c686 100644 --- a/tests/unit/async_/io/test_class_bolt5x4.py +++ b/tests/unit/async_/io/test_class_bolt5x4.py @@ -281,7 +281,7 @@ async def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) await connection.send_all() diff --git a/tests/unit/async_/io/test_class_bolt5x5.py b/tests/unit/async_/io/test_class_bolt5x5.py index 77d748de..d1ea0e51 100644 --- a/tests/unit/async_/io/test_class_bolt5x5.py +++ b/tests/unit/async_/io/test_class_bolt5x5.py @@ -281,7 +281,7 @@ async def test_telemetry_message( 
telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) await connection.send_all() diff --git a/tests/unit/async_/io/test_class_bolt5x6.py b/tests/unit/async_/io/test_class_bolt5x6.py index a5106572..533af97e 100644 --- a/tests/unit/async_/io/test_class_bolt5x6.py +++ b/tests/unit/async_/io/test_class_bolt5x6.py @@ -281,7 +281,7 @@ async def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) await connection.send_all() diff --git a/tests/unit/async_/io/test_class_bolt5x7.py b/tests/unit/async_/io/test_class_bolt5x7.py index 97a8b4ea..3758157a 100644 --- a/tests/unit/async_/io/test_class_bolt5x7.py +++ b/tests/unit/async_/io/test_class_bolt5x7.py @@ -282,7 +282,7 @@ async def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) await connection.send_all() diff --git a/tests/unit/async_/io/test_class_bolt5x8.py b/tests/unit/async_/io/test_class_bolt5x8.py new file mode 100644 index 00000000..e3d572d1 --- /dev/null +++ b/tests/unit/async_/io/test_class_bolt5x8.py @@ -0,0 +1,850 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
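+# Unit tests for the Bolt 5.8 protocol implementation (AsyncBolt5x8).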
+ + +import itertools +import logging + +import pytest + +import neo4j +from neo4j._api import TelemetryAPI +from neo4j._async.config import AsyncPoolConfig +from neo4j._async.io._bolt5 import AsyncBolt5x8 +from neo4j._meta import ( + BOLT_AGENT_DICT, + USER_AGENT, +) +from neo4j.exceptions import Neo4jError + +from ...._async_compat import mark_async_test +from ....iter_util import powerset + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 0 + connection = AsyncBolt5x8( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is True + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = -1 + connection = AsyncBolt5x8( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 999999999 + connection = AsyncBolt5x8( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize( + ("args", "kwargs", "expected_fields"), + ( + (("", {}), {"db": "something"}, ({"db": "something"},)), + (("", {}), {"imp_user": "imposter"}, ({"imp_user": "imposter"},)), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ({"db": "something", "imp_user": "imposter"},), + ), + ), +) +@mark_async_test +async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.begin(*args, **kwargs) + await connection.send_all() + tag, is_fields = await socket.pop_message() + assert tag == b"\x11" + assert tuple(is_fields) == expected_fields + + +@pytest.mark.parametrize( + ("args", "kwargs", "expected_fields"), + ( + (("", {}), {"db": "something"}, ("", {}, {"db": "something"})), + ( + ("", {}), + {"imp_user": "imposter"}, + ("", {}, {"imp_user": "imposter"}), + ), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ("", {}, {"db": "something", "imp_user": "imposter"}), + ), + ), +) +@mark_async_test +async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.run(*args, **kwargs) + await connection.send_all() + tag, is_fields = await socket.pop_message() + assert tag == b"\x10" + assert tuple(is_fields) == expected_fields + + +@mark_async_test +async def test_n_extra_in_discard(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.discard(n=666) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert 
fields[0] == {"n": 666} + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (666, {"n": -1, "qid": 666}), + (-1, {"n": -1}), + ], +) +@mark_async_test +async def test_qid_extra_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.discard(qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (777, {"n": 666, "qid": 777}), + (-1, {"n": 666}), + ], +) +@mark_async_test +async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.discard(n=666, qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (666, {"n": 666}), + (-1, {"n": -1}), + ], +) +@mark_async_test +async def test_n_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.pull(n=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (777, {"n": -1, "qid": 777}), + (-1, {"n": -1}), + ], +) +@mark_async_test +async def test_qid_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.pull(qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == expected + + +@mark_async_test +async def test_n_and_qid_extras_in_pull(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.pull(n=666, qid=777) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == {"n": 666, "qid": 777} + + +@mark_async_test +async def test_hello_passes_routing_metadata(fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.4.0"}) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x8( + address, + sockets.client, + AsyncPoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"}, + ) + await connection.hello() + tag, fields = await sockets.server.pop_message() + assert tag == b"\x01" + assert len(fields) == 1 + 
assert fields[0]["routing"] == {"foo": "bar"} + + +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_async_test +async def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, + socket, + AsyncPoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled, + ) + if serv_enabled: + connection.connection_hints["telemetry.enabled"] = True + connection.telemetry(api) + await connection.send_all() + + if serv_enabled and not driver_disabled: + tag, fields = await socket.pop_message() + assert tag == b"\x54" + assert fields == [int(api)] + else: + with pytest.raises(OSError): + await socket.pop_message() + + +@pytest.mark.parametrize( + ("hints", "valid"), + ( + ({"connection.recv_timeout_seconds": 1}, True), + ({"connection.recv_timeout_seconds": 42}, True), + ({}, True), + ({"whatever_this_is": "ignore me!"}, True), + ({"connection.recv_timeout_seconds": -1}, False), + ({"connection.recv_timeout_seconds": 0}, False), + ({"connection.recv_timeout_seconds": 2.5}, False), + ({"connection.recv_timeout_seconds": None}, False), + ({"connection.recv_timeout_seconds": False}, False), + ({"connection.recv_timeout_seconds": "1"}, False), + ), +) +@mark_async_test +async def test_hint_recv_timeout_seconds( + fake_socket_pair, hints, valid, caplog, mocker +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + sockets.client.settimeout = mocker.Mock() + await sockets.server.send_message( + b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} + ) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x8( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime + ) + with caplog.at_level(logging.INFO): + await connection.hello() + if valid: + if "connection.recv_timeout_seconds" in hints: + sockets.client.settimeout.assert_called_once_with( + hints["connection.recv_timeout_seconds"] + ) + else: + sockets.client.settimeout.assert_not_called() + assert not any( + "recv_timeout_seconds" in msg and "invalid" in msg + for msg in caplog.messages + ) + else: + sockets.client.settimeout.assert_not_called() + assert any( + repr(hints["connection.recv_timeout_seconds"]) in msg + and "recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages + ) + + +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize( + "auth", + ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), + ), +) +@mark_async_test +async def test_credentials_are_not_logged(auth, fake_socket_pair, caplog): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x8( + address, + sockets.client, + AsyncPoolConfig.max_connection_lifetime, + auth=auth, + ) + with caplog.at_level(logging.DEBUG): + await 
connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +def _assert_notifications_in_extra(extra, expected): + for key in expected: + assert key in extra + assert extra[key] == expected[key] + + +@pytest.mark.parametrize( + ("method", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ), +) +@pytest.mark.parametrize( + ("cls_min_sev", "method_min_sev"), + itertools.product((None, "WARNING", "OFF"), repeat=2), +) +@pytest.mark.parametrize( + ("cls_dis_clss", "method_dis_clss"), + itertools.product((None, [], ["HINT"], ["HINT", "DEPRECATION"]), repeat=2), +) +@mark_async_test +async def test_supports_notification_filters( + fake_socket, + method, + args, + extra_idx, + cls_min_sev, + method_min_sev, + cls_dis_clss, + method_dis_clss, +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, + socket, + AsyncPoolConfig.max_connection_lifetime, + notifications_min_severity=cls_min_sev, + notifications_disabled_classifications=cls_dis_clss, + ) + method = getattr(connection, method) + + method( + *args, + notifications_min_severity=method_min_sev, + notifications_disabled_classifications=method_dis_clss, + ) + await connection.send_all() + + _, fields = await socket.pop_message() + extra = fields[extra_idx] + expected = {} + if method_min_sev is not None: + expected["notifications_minimum_severity"] = method_min_sev + if method_dis_clss is not None: + expected["notifications_disabled_classifications"] = method_dis_clss + _assert_notifications_in_extra(extra, expected) + + +@pytest.mark.parametrize("min_sev", (None, "WARNING", "OFF")) +@pytest.mark.parametrize( + "dis_clss", (None, [], ["HINT"], ["HINT", "DEPRECATION"]) +) +@mark_async_test +async def test_hello_supports_notification_filters( + fake_socket_pair, min_sev, dis_clss +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x8( + address, + sockets.client, + AsyncPoolConfig.max_connection_lifetime, + notifications_min_severity=min_sev, + notifications_disabled_classifications=dis_clss, + ) + + await connection.hello() + + _tag, fields = await sockets.server.pop_message() + extra = fields[0] + expected = {} + if min_sev is not None: + expected["notifications_minimum_severity"] = min_sev + if dis_clss is not None: + expected["notifications_disabled_classifications"] = dis_clss + _assert_notifications_in_extra(extra, expected) + + +@mark_async_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +async def test_user_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = AsyncBolt5x8( + address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + await connection.hello() + + 
_tag, fields = await sockets.server.pop_message() + extra = fields[0] + if not user_agent: + assert extra["user_agent"] == USER_AGENT + else: + assert extra["user_agent"] == user_agent + + +@mark_async_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +async def test_sends_bolt_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = AsyncBolt5x8( + address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + await connection.hello() + + _tag, fields = await sockets.server.pop_message() + extra = fields[0] + assert extra["bolt_agent"] == BOLT_AGENT_DICT + + +@mark_async_test +@pytest.mark.parametrize( + ("func", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ), +) +@pytest.mark.parametrize( + ("timeout", "res"), + ( + (None, None), + (0, 0), + (0.1, 100), + (0.001, 1), + (1e-15, 1), + (0.0005, 1), + (0.0001, 1), + (1.0015, 1002), + (1.000499, 1000), + (1.0025, 1002), + (3.0005, 3000), + (3.456, 3456), + (1, 1000), + ( + -1e-15, + ValueError("Timeout must be a positive number or 0"), + ), + ( + "foo", + ValueError("Timeout must be specified as a number of seconds"), + ), + ( + [1, 2], + TypeError("Timeout must be specified as a number of seconds"), + ), + ), +) +async def test_tx_timeout( + fake_socket_pair, func, args, extra_idx, timeout, res +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x8(address, sockets.client, 0) + func = getattr(connection, func) + if isinstance(res, Exception): + with pytest.raises(type(res), match=str(res)): + func(*args, timeout=timeout) + else: + func(*args, timeout=timeout) + await connection.send_all() + _tag, fields = await sockets.server.pop_message() + extra = fields[extra_idx] + if timeout is None: + assert "tx_timeout" not in extra + else: + assert extra["tx_timeout"] == res + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2, + ), +) +@mark_async_test +async def test_tracks_last_database(fake_socket_pair, actions): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + connection = AsyncBolt5x8(address, sockets.client, 0) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + await connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + await sockets.server.send_message(b"\x70", {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + await sockets.server.send_message(b"\x70", {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + await 
connection.send_all() + await connection.fetch_all() + assert connection.last_database == db + + await sockets.server.send_message(b"\x70", {}) + if finish == "reset": + await connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + await connection.send_all() + await connection.fetch_all() + + assert connection.last_database == db + + +DEFAULT_DIAG_REC_PAIRS = ( + ("OPERATION", ""), + ("OPERATION_CODE", "0"), + ("CURRENT_SCHEMA", "/"), +) + + +@pytest.mark.parametrize( + "sent_diag_records", + powerset( + ( + ..., + None, + {}, + [], + "1", + 1, + {"OPERATION_CODE": "0"}, + {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, + {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, + {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, + ), + upper_limit=3, + ), +) +@pytest.mark.parametrize("method", ("pull", "discard")) +@mark_async_test +async def test_enriches_statuses( + sent_diag_records, + method, + fake_socket_pair, +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + connection = AsyncBolt5x8(address, sockets.client, 0) + + sent_metadata = { + "statuses": [ + { + "status_description": "the status description", + "description": "description", + "diagnostic_record": r, + } + if r is not ... + else { + "status_description": "the status description", + "description": "description", + } + for r in sent_diag_records + ] + } + await sockets.server.send_message(b"\x70", sent_metadata) + + received_metadata = None + + def on_success(metadata): + nonlocal received_metadata + received_metadata = metadata + + getattr(connection, method)(on_success=on_success) + await connection.send_all() + await connection.fetch_all() + + def extend_diag_record(r): + if r is ...: + return dict(DEFAULT_DIAG_REC_PAIRS) + if isinstance(r, dict): + return dict((*DEFAULT_DIAG_REC_PAIRS, *r.items())) + return r + + expected_diag_records = [extend_diag_record(r) for r in sent_diag_records] + expected_metadata = { + "statuses": [ + { + "status_description": "the status description", + "description": "description", + "diagnostic_record": r, + } + if r is not ... 
+ else { + "status_description": "the status description", + "description": "description", + } + for r in expected_diag_records + ] + } + + assert received_metadata == expected_metadata + + +@pytest.mark.parametrize( + "sent_diag_records", + powerset( + ( + ..., + None, + {}, + [], + "1", + 1, + {"OPERATION_CODE": "0"}, + {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, + {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, + {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, + ), + lower_limit=1, + upper_limit=3, + ), +) +@mark_async_test +async def test_enriches_error_statuses( + sent_diag_records, + fake_socket_pair, +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + connection = AsyncBolt5x8(address, sockets.client, 0) + sent_diag_records = [ + {**r, "_classification": "CLIENT_ERROR", "_status_parameters": {}} + if isinstance(r, dict) + else r + for r in sent_diag_records + ] + + sent_metadata = _build_error_hierarchy_metadata(sent_diag_records) + + await sockets.server.send_message(b"\x7f", sent_metadata) + + received_metadata = None + + def on_failure(metadata): + nonlocal received_metadata + received_metadata = metadata + + connection.run("RETURN 1", on_failure=on_failure) + await connection.send_all() + with pytest.raises(Neo4jError): + await connection.fetch_all() + + def extend_diag_record(r): + if r is ...: + return dict(DEFAULT_DIAG_REC_PAIRS) + if isinstance(r, dict): + return dict((*DEFAULT_DIAG_REC_PAIRS, *r.items())) + return r + + expected_diag_records = [extend_diag_record(r) for r in sent_diag_records] + expected_metadata = _build_error_hierarchy_metadata(expected_diag_records) + + assert received_metadata == expected_metadata + + +def _build_error_hierarchy_metadata(diag_records_metadata): + metadata = { + "gql_status": "FOO12", + "description": "but have you tried not doing that?!", + "message": "some people just can't be helped", + "neo4j_code": "Neo.ClientError.Generic.YouSuck", + } + if diag_records_metadata[0] is not ...: + metadata["diagnostic_record"] = diag_records_metadata[0] + current_root = metadata + for i, r in enumerate(diag_records_metadata[1:]): + current_root["cause"] = { + "description": f"error cause nr. 
{i + 1}", + "message": f"cause message {i + 1}", + } + current_root = current_root["cause"] + if r is not ...: + current_root["diagnostic_record"] = r + return metadata diff --git a/tests/unit/async_/work/test_result.py b/tests/unit/async_/work/test_result.py index 1291d95e..6141b998 100644 --- a/tests/unit/async_/work/test_result.py +++ b/tests/unit/async_/work/test_result.py @@ -1366,7 +1366,9 @@ async def test_notification_warning( ] }, ) - result = AsyncResult(connection, 1, warn_notification_severity, noop, noop) + result = AsyncResult( + connection, 1, warn_notification_severity, noop, noop, None + ) if expected_warning is None: with warnings.catch_warnings(): warnings.simplefilter("error") # assert not warnings are emitted @@ -1408,7 +1410,7 @@ async def test_notification_logging( records=Records(["foo"], ()), summary_meta={"notifications": [notification_data]}, ) - result = AsyncResult(connection, 1, None, noop, noop) + result = AsyncResult(connection, 1, None, noop, noop, None) with caplog.at_level(logging.INFO, logger="neo4j.notifications"): await result._run("CYPHER", {}, None, None, "r", None, None, None) await result.consume() diff --git a/tests/unit/async_/work/test_transaction.py b/tests/unit/async_/work/test_transaction.py index 81fba42b..ab23ce7b 100644 --- a/tests/unit/async_/work/test_transaction.py +++ b/tests/unit/async_/work/test_transaction.py @@ -248,7 +248,9 @@ async def test_transaction_no_rollback_on_defunct_connections( async def test_transaction_begin_pipelining( async_fake_connection, pipeline ) -> None: - tx = AsyncTransaction(async_fake_connection, 2, None, noop, noop, noop) + tx = AsyncTransaction( + async_fake_connection, 2, None, noop, noop, noop, None + ) database = "db" imp_user = None bookmarks = ["bookmark1", "bookmark2"] diff --git a/tests/unit/common/work/test_summary.py b/tests/unit/common/work/test_summary.py index 74f2059a..46c85c53 100644 --- a/tests/unit/common/work/test_summary.py +++ b/tests/unit/common/work/test_summary.py @@ -890,6 +890,7 @@ def test_summary_result_counters(summary_args_kwargs, counters_set) -> None: ((5, 5), "t_first"), ((5, 6), "t_first"), ((5, 7), "t_first"), + ((5, 8), "t_first"), ), ) def test_summary_result_available_after( @@ -927,6 +928,7 @@ def test_summary_result_available_after( ((5, 5), "t_last"), ((5, 6), "t_last"), ((5, 7), "t_last"), + ((5, 8), "t_last"), ), ) def test_summary_result_consumed_after( diff --git a/tests/unit/sync/io/test_class_bolt.py b/tests/unit/sync/io/test_class_bolt.py index f3b06303..c3d6bace 100644 --- a/tests/unit/sync/io/test_class_bolt.py +++ b/tests/unit/sync/io/test_class_bolt.py @@ -38,7 +38,7 @@ def test_class_method_protocol_handlers(): expected_handlers = { (3, 0), (4, 1), (4, 2), (4, 3), (4, 4), - (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 7), + (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 7), (5, 8), } # fmt: on @@ -69,7 +69,8 @@ def test_class_method_protocol_handlers(): ((5, 5), 1), ((5, 6), 1), ((5, 7), 1), - ((5, 8), 0), + ((5, 8), 1), + ((5, 9), 0), ((6, 0), 0), ], ) @@ -92,7 +93,7 @@ def test_class_method_get_handshake(): handshake = Bolt.get_handshake() assert ( handshake - == b"\x00\x07\x07\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" + == b"\x00\x08\x08\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" ) @@ -143,6 +144,7 @@ def test_cancel_hello_in_open(mocker, none_auth): ((5, 5), "neo4j._sync.io._bolt5.Bolt5x5"), ((5, 6), "neo4j._sync.io._bolt5.Bolt5x6"), ((5, 7), "neo4j._sync.io._bolt5.Bolt5x7"), + ((5, 8), 
"neo4j._sync.io._bolt5.Bolt5x8"), ), ) @mark_sync_test @@ -181,7 +183,7 @@ def test_version_negotiation( (2, 0), (4, 0), (3, 1), - (5, 8), + (5, 9), (6, 0), ), ) diff --git a/tests/unit/sync/io/test_class_bolt3.py b/tests/unit/sync/io/test_class_bolt3.py index ba80ce81..980694bb 100644 --- a/tests/unit/sync/io/test_class_bolt3.py +++ b/tests/unit/sync/io/test_class_bolt3.py @@ -129,7 +129,7 @@ def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) connection.send_all() diff --git a/tests/unit/sync/io/test_class_bolt4x0.py b/tests/unit/sync/io/test_class_bolt4x0.py index a0ad36e8..be3f4499 100644 --- a/tests/unit/sync/io/test_class_bolt4x0.py +++ b/tests/unit/sync/io/test_class_bolt4x0.py @@ -232,7 +232,7 @@ def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) connection.send_all() diff --git a/tests/unit/sync/io/test_class_bolt4x1.py b/tests/unit/sync/io/test_class_bolt4x1.py index c4b0208a..77b54513 100644 --- a/tests/unit/sync/io/test_class_bolt4x1.py +++ b/tests/unit/sync/io/test_class_bolt4x1.py @@ -254,7 +254,7 @@ def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) connection.send_all() diff --git a/tests/unit/sync/io/test_class_bolt4x2.py b/tests/unit/sync/io/test_class_bolt4x2.py index b6ac961a..65525f9e 100644 --- a/tests/unit/sync/io/test_class_bolt4x2.py +++ b/tests/unit/sync/io/test_class_bolt4x2.py @@ -254,7 +254,7 @@ def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) connection.send_all() diff --git a/tests/unit/sync/io/test_class_bolt4x3.py b/tests/unit/sync/io/test_class_bolt4x3.py index c5da8700..a4a61ba1 100644 --- a/tests/unit/sync/io/test_class_bolt4x3.py +++ b/tests/unit/sync/io/test_class_bolt4x3.py @@ -254,7 +254,7 @@ def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) connection.send_all() diff --git a/tests/unit/sync/io/test_class_bolt4x4.py b/tests/unit/sync/io/test_class_bolt4x4.py index 164372b0..692f64f8 100644 --- a/tests/unit/sync/io/test_class_bolt4x4.py +++ b/tests/unit/sync/io/test_class_bolt4x4.py @@ -278,7 +278,7 @@ def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) connection.send_all() diff --git a/tests/unit/sync/io/test_class_bolt5x0.py b/tests/unit/sync/io/test_class_bolt5x0.py index 6f26b97a..5bc3e2c2 100644 --- a/tests/unit/sync/io/test_class_bolt5x0.py +++ b/tests/unit/sync/io/test_class_bolt5x0.py @@ -278,7 +278,7 @@ def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = 
True connection.telemetry(api) connection.send_all() diff --git a/tests/unit/sync/io/test_class_bolt5x1.py b/tests/unit/sync/io/test_class_bolt5x1.py index dfe638a9..4376e39e 100644 --- a/tests/unit/sync/io/test_class_bolt5x1.py +++ b/tests/unit/sync/io/test_class_bolt5x1.py @@ -280,7 +280,7 @@ def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) connection.send_all() diff --git a/tests/unit/sync/io/test_class_bolt5x2.py b/tests/unit/sync/io/test_class_bolt5x2.py index 5dc09be8..f2d0db48 100644 --- a/tests/unit/sync/io/test_class_bolt5x2.py +++ b/tests/unit/sync/io/test_class_bolt5x2.py @@ -278,7 +278,7 @@ def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) connection.send_all() diff --git a/tests/unit/sync/io/test_class_bolt5x3.py b/tests/unit/sync/io/test_class_bolt5x3.py index af852710..fecd4d88 100644 --- a/tests/unit/sync/io/test_class_bolt5x3.py +++ b/tests/unit/sync/io/test_class_bolt5x3.py @@ -281,7 +281,7 @@ def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) connection.send_all() diff --git a/tests/unit/sync/io/test_class_bolt5x4.py b/tests/unit/sync/io/test_class_bolt5x4.py index 5773d1f6..7740449d 100644 --- a/tests/unit/sync/io/test_class_bolt5x4.py +++ b/tests/unit/sync/io/test_class_bolt5x4.py @@ -281,7 +281,7 @@ def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) connection.send_all() diff --git a/tests/unit/sync/io/test_class_bolt5x5.py b/tests/unit/sync/io/test_class_bolt5x5.py index 361a9c14..c301de17 100644 --- a/tests/unit/sync/io/test_class_bolt5x5.py +++ b/tests/unit/sync/io/test_class_bolt5x5.py @@ -281,7 +281,7 @@ def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) connection.send_all() diff --git a/tests/unit/sync/io/test_class_bolt5x6.py b/tests/unit/sync/io/test_class_bolt5x6.py index 15f37872..1f61f05e 100644 --- a/tests/unit/sync/io/test_class_bolt5x6.py +++ b/tests/unit/sync/io/test_class_bolt5x6.py @@ -281,7 +281,7 @@ def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) connection.send_all() diff --git a/tests/unit/sync/io/test_class_bolt5x7.py b/tests/unit/sync/io/test_class_bolt5x7.py index cf999cc6..7d6523ee 100644 --- a/tests/unit/sync/io/test_class_bolt5x7.py +++ b/tests/unit/sync/io/test_class_bolt5x7.py @@ -282,7 +282,7 @@ def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) connection.send_all() diff --git 
a/tests/unit/sync/io/test_class_bolt5x8.py b/tests/unit/sync/io/test_class_bolt5x8.py new file mode 100644 index 00000000..f172dc2a --- /dev/null +++ b/tests/unit/sync/io/test_class_bolt5x8.py @@ -0,0 +1,850 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import itertools +import logging + +import pytest + +import neo4j +from neo4j._api import TelemetryAPI +from neo4j._meta import ( + BOLT_AGENT_DICT, + USER_AGENT, +) +from neo4j._sync.config import PoolConfig +from neo4j._sync.io._bolt5 import Bolt5x8 +from neo4j.exceptions import Neo4jError + +from ...._async_compat import mark_sync_test +from ....iter_util import powerset + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 0 + connection = Bolt5x8( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is True + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = -1 + connection = Bolt5x8( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 999999999 + connection = Bolt5x8( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize( + ("args", "kwargs", "expected_fields"), + ( + (("", {}), {"db": "something"}, ({"db": "something"},)), + (("", {}), {"imp_user": "imposter"}, ({"imp_user": "imposter"},)), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ({"db": "something", "imp_user": "imposter"},), + ), + ), +) +@mark_sync_test +def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.begin(*args, **kwargs) + connection.send_all() + tag, is_fields = socket.pop_message() + assert tag == b"\x11" + assert tuple(is_fields) == expected_fields + + +@pytest.mark.parametrize( + ("args", "kwargs", "expected_fields"), + ( + (("", {}), {"db": "something"}, ("", {}, {"db": "something"})), + ( + ("", {}), + {"imp_user": "imposter"}, + ("", {}, {"imp_user": "imposter"}), + ), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ("", {}, {"db": "something", "imp_user": "imposter"}), + ), + ), +) +@mark_sync_test +def test_extra_in_run(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) 
+ socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.run(*args, **kwargs) + connection.send_all() + tag, is_fields = socket.pop_message() + assert tag == b"\x10" + assert tuple(is_fields) == expected_fields + + +@mark_sync_test +def test_n_extra_in_discard(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.discard(n=666) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == {"n": 666} + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (666, {"n": -1, "qid": 666}), + (-1, {"n": -1}), + ], +) +@mark_sync_test +def test_qid_extra_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.discard(qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (777, {"n": 666, "qid": 777}), + (-1, {"n": 666}), + ], +) +@mark_sync_test +def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.discard(n=666, qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (666, {"n": 666}), + (-1, {"n": -1}), + ], +) +@mark_sync_test +def test_n_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.pull(n=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (777, {"n": -1, "qid": 777}), + (-1, {"n": -1}), + ], +) +@mark_sync_test +def test_qid_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.pull(qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == expected + + +@mark_sync_test +def test_n_and_qid_extras_in_pull(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.pull(n=666, qid=777) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == {"n": 666, "qid": 777} + + +@mark_sync_test +def test_hello_passes_routing_metadata(fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 
7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.4.0"}) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x8( + address, + sockets.client, + PoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"}, + ) + connection.hello() + tag, fields = sockets.server.pop_message() + assert tag == b"\x01" + assert len(fields) == 1 + assert fields[0]["routing"] == {"foo": "bar"} + + +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_sync_test +def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, + socket, + PoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled, + ) + if serv_enabled: + connection.connection_hints["telemetry.enabled"] = True + connection.telemetry(api) + connection.send_all() + + if serv_enabled and not driver_disabled: + tag, fields = socket.pop_message() + assert tag == b"\x54" + assert fields == [int(api)] + else: + with pytest.raises(OSError): + socket.pop_message() + + +@pytest.mark.parametrize( + ("hints", "valid"), + ( + ({"connection.recv_timeout_seconds": 1}, True), + ({"connection.recv_timeout_seconds": 42}, True), + ({}, True), + ({"whatever_this_is": "ignore me!"}, True), + ({"connection.recv_timeout_seconds": -1}, False), + ({"connection.recv_timeout_seconds": 0}, False), + ({"connection.recv_timeout_seconds": 2.5}, False), + ({"connection.recv_timeout_seconds": None}, False), + ({"connection.recv_timeout_seconds": False}, False), + ({"connection.recv_timeout_seconds": "1"}, False), + ), +) +@mark_sync_test +def test_hint_recv_timeout_seconds( + fake_socket_pair, hints, valid, caplog, mocker +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + sockets.client.settimeout = mocker.Mock() + sockets.server.send_message( + b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} + ) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x8( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + with caplog.at_level(logging.INFO): + connection.hello() + if valid: + if "connection.recv_timeout_seconds" in hints: + sockets.client.settimeout.assert_called_once_with( + hints["connection.recv_timeout_seconds"] + ) + else: + sockets.client.settimeout.assert_not_called() + assert not any( + "recv_timeout_seconds" in msg and "invalid" in msg + for msg in caplog.messages + ) + else: + sockets.client.settimeout.assert_not_called() + assert any( + repr(hints["connection.recv_timeout_seconds"]) in msg + and "recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages + ) + + +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize( + "auth", + ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), + ), +) +@mark_sync_test +def test_credentials_are_not_logged(auth, fake_socket_pair, caplog): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + 
address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x8( + address, + sockets.client, + PoolConfig.max_connection_lifetime, + auth=auth, + ) + with caplog.at_level(logging.DEBUG): + connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +def _assert_notifications_in_extra(extra, expected): + for key in expected: + assert key in extra + assert extra[key] == expected[key] + + +@pytest.mark.parametrize( + ("method", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ), +) +@pytest.mark.parametrize( + ("cls_min_sev", "method_min_sev"), + itertools.product((None, "WARNING", "OFF"), repeat=2), +) +@pytest.mark.parametrize( + ("cls_dis_clss", "method_dis_clss"), + itertools.product((None, [], ["HINT"], ["HINT", "DEPRECATION"]), repeat=2), +) +@mark_sync_test +def test_supports_notification_filters( + fake_socket, + method, + args, + extra_idx, + cls_min_sev, + method_min_sev, + cls_dis_clss, + method_dis_clss, +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, + socket, + PoolConfig.max_connection_lifetime, + notifications_min_severity=cls_min_sev, + notifications_disabled_classifications=cls_dis_clss, + ) + method = getattr(connection, method) + + method( + *args, + notifications_min_severity=method_min_sev, + notifications_disabled_classifications=method_dis_clss, + ) + connection.send_all() + + _, fields = socket.pop_message() + extra = fields[extra_idx] + expected = {} + if method_min_sev is not None: + expected["notifications_minimum_severity"] = method_min_sev + if method_dis_clss is not None: + expected["notifications_disabled_classifications"] = method_dis_clss + _assert_notifications_in_extra(extra, expected) + + +@pytest.mark.parametrize("min_sev", (None, "WARNING", "OFF")) +@pytest.mark.parametrize( + "dis_clss", (None, [], ["HINT"], ["HINT", "DEPRECATION"]) +) +@mark_sync_test +def test_hello_supports_notification_filters( + fake_socket_pair, min_sev, dis_clss +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x8( + address, + sockets.client, + PoolConfig.max_connection_lifetime, + notifications_min_severity=min_sev, + notifications_disabled_classifications=dis_clss, + ) + + connection.hello() + + _tag, fields = sockets.server.pop_message() + extra = fields[0] + expected = {} + if min_sev is not None: + expected["notifications_minimum_severity"] = min_sev + if dis_clss is not None: + expected["notifications_disabled_classifications"] = dis_clss + _assert_notifications_in_extra(extra, expected) + + +@mark_sync_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +def test_user_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {"server": 
"Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = Bolt5x8( + address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + connection.hello() + + _tag, fields = sockets.server.pop_message() + extra = fields[0] + if not user_agent: + assert extra["user_agent"] == USER_AGENT + else: + assert extra["user_agent"] == user_agent + + +@mark_sync_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +def test_sends_bolt_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = Bolt5x8( + address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + connection.hello() + + _tag, fields = sockets.server.pop_message() + extra = fields[0] + assert extra["bolt_agent"] == BOLT_AGENT_DICT + + +@mark_sync_test +@pytest.mark.parametrize( + ("func", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ), +) +@pytest.mark.parametrize( + ("timeout", "res"), + ( + (None, None), + (0, 0), + (0.1, 100), + (0.001, 1), + (1e-15, 1), + (0.0005, 1), + (0.0001, 1), + (1.0015, 1002), + (1.000499, 1000), + (1.0025, 1002), + (3.0005, 3000), + (3.456, 3456), + (1, 1000), + ( + -1e-15, + ValueError("Timeout must be a positive number or 0"), + ), + ( + "foo", + ValueError("Timeout must be specified as a number of seconds"), + ), + ( + [1, 2], + TypeError("Timeout must be specified as a number of seconds"), + ), + ), +) +def test_tx_timeout( + fake_socket_pair, func, args, extra_idx, timeout, res +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x8(address, sockets.client, 0) + func = getattr(connection, func) + if isinstance(res, Exception): + with pytest.raises(type(res), match=str(res)): + func(*args, timeout=timeout) + else: + func(*args, timeout=timeout) + connection.send_all() + _tag, fields = sockets.server.pop_message() + extra = fields[extra_idx] + if timeout is None: + assert "tx_timeout" not in extra + else: + assert extra["tx_timeout"] == res + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2, + ), +) +@mark_sync_test +def test_tracks_last_database(fake_socket_pair, actions): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + connection = Bolt5x8(address, sockets.client, 0) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + sockets.server.send_message(b"\x70", {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + sockets.server.send_message(b"\x70", {}) + connection.run("RETURN 1") + else: + raise 
ValueError(action) + + assert connection.last_database == db + connection.send_all() + connection.fetch_all() + assert connection.last_database == db + + sockets.server.send_message(b"\x70", {}) + if finish == "reset": + connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + connection.send_all() + connection.fetch_all() + + assert connection.last_database == db + + +DEFAULT_DIAG_REC_PAIRS = ( + ("OPERATION", ""), + ("OPERATION_CODE", "0"), + ("CURRENT_SCHEMA", "/"), +) + + +@pytest.mark.parametrize( + "sent_diag_records", + powerset( + ( + ..., + None, + {}, + [], + "1", + 1, + {"OPERATION_CODE": "0"}, + {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, + {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, + {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, + ), + upper_limit=3, + ), +) +@pytest.mark.parametrize("method", ("pull", "discard")) +@mark_sync_test +def test_enriches_statuses( + sent_diag_records, + method, + fake_socket_pair, +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + connection = Bolt5x8(address, sockets.client, 0) + + sent_metadata = { + "statuses": [ + { + "status_description": "the status description", + "description": "description", + "diagnostic_record": r, + } + if r is not ... + else { + "status_description": "the status description", + "description": "description", + } + for r in sent_diag_records + ] + } + sockets.server.send_message(b"\x70", sent_metadata) + + received_metadata = None + + def on_success(metadata): + nonlocal received_metadata + received_metadata = metadata + + getattr(connection, method)(on_success=on_success) + connection.send_all() + connection.fetch_all() + + def extend_diag_record(r): + if r is ...: + return dict(DEFAULT_DIAG_REC_PAIRS) + if isinstance(r, dict): + return dict((*DEFAULT_DIAG_REC_PAIRS, *r.items())) + return r + + expected_diag_records = [extend_diag_record(r) for r in sent_diag_records] + expected_metadata = { + "statuses": [ + { + "status_description": "the status description", + "description": "description", + "diagnostic_record": r, + } + if r is not ... 
+ else { + "status_description": "the status description", + "description": "description", + } + for r in expected_diag_records + ] + } + + assert received_metadata == expected_metadata + + +@pytest.mark.parametrize( + "sent_diag_records", + powerset( + ( + ..., + None, + {}, + [], + "1", + 1, + {"OPERATION_CODE": "0"}, + {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, + {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, + {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, + ), + lower_limit=1, + upper_limit=3, + ), +) +@mark_sync_test +def test_enriches_error_statuses( + sent_diag_records, + fake_socket_pair, +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + connection = Bolt5x8(address, sockets.client, 0) + sent_diag_records = [ + {**r, "_classification": "CLIENT_ERROR", "_status_parameters": {}} + if isinstance(r, dict) + else r + for r in sent_diag_records + ] + + sent_metadata = _build_error_hierarchy_metadata(sent_diag_records) + + sockets.server.send_message(b"\x7f", sent_metadata) + + received_metadata = None + + def on_failure(metadata): + nonlocal received_metadata + received_metadata = metadata + + connection.run("RETURN 1", on_failure=on_failure) + connection.send_all() + with pytest.raises(Neo4jError): + connection.fetch_all() + + def extend_diag_record(r): + if r is ...: + return dict(DEFAULT_DIAG_REC_PAIRS) + if isinstance(r, dict): + return dict((*DEFAULT_DIAG_REC_PAIRS, *r.items())) + return r + + expected_diag_records = [extend_diag_record(r) for r in sent_diag_records] + expected_metadata = _build_error_hierarchy_metadata(expected_diag_records) + + assert received_metadata == expected_metadata + + +def _build_error_hierarchy_metadata(diag_records_metadata): + metadata = { + "gql_status": "FOO12", + "description": "but have you tried not doing that?!", + "message": "some people just can't be helped", + "neo4j_code": "Neo.ClientError.Generic.YouSuck", + } + if diag_records_metadata[0] is not ...: + metadata["diagnostic_record"] = diag_records_metadata[0] + current_root = metadata + for i, r in enumerate(diag_records_metadata[1:]): + current_root["cause"] = { + "description": f"error cause nr. 
{i + 1}", + "message": f"cause message {i + 1}", + } + current_root = current_root["cause"] + if r is not ...: + current_root["diagnostic_record"] = r + return metadata diff --git a/tests/unit/sync/work/test_result.py b/tests/unit/sync/work/test_result.py index 623d5014..6bb71e8b 100644 --- a/tests/unit/sync/work/test_result.py +++ b/tests/unit/sync/work/test_result.py @@ -1366,7 +1366,9 @@ def test_notification_warning( ] }, ) - result = Result(connection, 1, warn_notification_severity, noop, noop) + result = Result( + connection, 1, warn_notification_severity, noop, noop, None + ) if expected_warning is None: with warnings.catch_warnings(): warnings.simplefilter("error") # assert not warnings are emitted @@ -1408,7 +1410,7 @@ def test_notification_logging( records=Records(["foo"], ()), summary_meta={"notifications": [notification_data]}, ) - result = Result(connection, 1, None, noop, noop) + result = Result(connection, 1, None, noop, noop, None) with caplog.at_level(logging.INFO, logger="neo4j.notifications"): result._run("CYPHER", {}, None, None, "r", None, None, None) result.consume() diff --git a/tests/unit/sync/work/test_transaction.py b/tests/unit/sync/work/test_transaction.py index 53aeba1e..683ff45d 100644 --- a/tests/unit/sync/work/test_transaction.py +++ b/tests/unit/sync/work/test_transaction.py @@ -248,7 +248,9 @@ def test_transaction_no_rollback_on_defunct_connections( def test_transaction_begin_pipelining( fake_connection, pipeline ) -> None: - tx = Transaction(fake_connection, 2, None, noop, noop, noop) + tx = Transaction( + fake_connection, 2, None, noop, noop, noop, None + ) database = "db" imp_user = None bookmarks = ["bookmark1", "bookmark2"] From 4db49eb4a57f3c9cb05c5a5d56e29c0f7d34c1b6 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Mon, 4 Nov 2024 12:20:25 +0100 Subject: [PATCH 03/26] Improve logging --- src/neo4j/_async/work/workspace.py | 4 ++++ src/neo4j/_sync/work/workspace.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/src/neo4j/_async/work/workspace.py b/src/neo4j/_async/work/workspace.py index 22ecf1ed..a3066f1c 100644 --- a/src/neo4j/_async/work/workspace.py +++ b/src/neo4j/_async/work/workspace.py @@ -219,6 +219,10 @@ async def _connect(self, access_mode, auth=None, **acquire_kwargs) -> None: # race condition: in the meantime, the pool added a connection, # which does not support SSR. # => we need to fall back to explicit home database resolution + log.debug( + "[#0000] _: detected ssr support race; " + "falling back to explicit home database resolution", + ) await self._disconnect() routing_target = await self._get_routing_target_database( acquire_auth, ssr_enabled=False diff --git a/src/neo4j/_sync/work/workspace.py b/src/neo4j/_sync/work/workspace.py index e13ba382..bc979988 100644 --- a/src/neo4j/_sync/work/workspace.py +++ b/src/neo4j/_sync/work/workspace.py @@ -216,6 +216,10 @@ def _connect(self, access_mode, auth=None, **acquire_kwargs) -> None: # race condition: in the meantime, the pool added a connection, # which does not support SSR. 
# => we need to fall back to explicit home database resolution + log.debug( + "[#0000] _: detected ssr support race; " + "falling back to explicit home database resolution", + ) self._disconnect() routing_target = self._get_routing_target_database( acquire_auth, ssr_enabled=False From 83e3ebe54ed108b439dc6e385bfe4020fc034c7a Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Mon, 4 Nov 2024 12:21:06 +0100 Subject: [PATCH 04/26] Clean up code --- src/neo4j/_async/work/workspace.py | 9 ++------- src/neo4j/_sync/work/workspace.py | 9 ++------- 2 files changed, 4 insertions(+), 14 deletions(-) diff --git a/src/neo4j/_async/work/workspace.py b/src/neo4j/_async/work/workspace.py index a3066f1c..e76cfe4f 100644 --- a/src/neo4j/_async/work/workspace.py +++ b/src/neo4j/_async/work/workspace.py @@ -34,10 +34,7 @@ SessionExpired, ) from .._debug import AsyncNonConcurrentMethodChecker -from ..io import ( - AcquireAuth, - AsyncNeo4jPool, -) +from ..io import AcquireAuth if t.TYPE_CHECKING: @@ -236,9 +233,7 @@ async def _get_routing_target_database( acquire_auth: AcquireAuth, ssr_enabled: bool, ) -> _TargetDatabase: - if self._config.database is not None or not isinstance( - self._pool, AsyncNeo4jPool - ): + if self._config.database is not None or self._pool.is_direct_pool: self._set_pinned_database(self._config.database) log.debug( "[#0000] _: routing towards fixed database: %s", diff --git a/src/neo4j/_sync/work/workspace.py b/src/neo4j/_sync/work/workspace.py index bc979988..d84938fb 100644 --- a/src/neo4j/_sync/work/workspace.py +++ b/src/neo4j/_sync/work/workspace.py @@ -34,10 +34,7 @@ SessionExpired, ) from .._debug import NonConcurrentMethodChecker -from ..io import ( - AcquireAuth, - Neo4jPool, -) +from ..io import AcquireAuth if t.TYPE_CHECKING: @@ -233,9 +230,7 @@ def _get_routing_target_database( acquire_auth: AcquireAuth, ssr_enabled: bool, ) -> _TargetDatabase: - if self._config.database is not None or not isinstance( - self._pool, Neo4jPool - ): + if self._config.database is not None or self._pool.is_direct_pool: self._set_pinned_database(self._config.database) log.debug( "[#0000] _: routing towards fixed database: %s", From 96a30f8be1698596cca18cd74bffa225949bc369 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Mon, 4 Nov 2024 12:21:42 +0100 Subject: [PATCH 05/26] Remove unused TestKit feature from backend The feature is a left-over from the old home db cache spike --- testkitbackend/test_config.json | 1 - 1 file changed, 1 deletion(-) diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index b4b357a6..bca7f0ca 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -61,7 +61,6 @@ "Feature:Bolt:5.7": true, "Feature:Bolt:Patch:UTC": true, "Feature:Impersonation": true, - "Feature:HomeDbCache": true, "Feature:TLS:1.1": "Driver blocks TLS 1.1 for security reasons.", "Feature:TLS:1.2": true, "Feature:TLS:1.3": "Depends on the machine (will be calculated dynamically).", From ec646d43f20a9be6844478c9401edede39dfe13a Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Mon, 11 Nov 2024 09:31:44 +0100 Subject: [PATCH 06/26] Don't send db name to server when taken from cache --- src/neo4j/_async/io/__init__.py | 6 ++- src/neo4j/_async/io/_pool.py | 74 ++++++++++++++++++++--------- src/neo4j/_async/work/session.py | 4 +- src/neo4j/_async/work/workspace.py | 54 ++++++++------------- src/neo4j/_sync/io/__init__.py | 6 ++- src/neo4j/_sync/io/_pool.py | 74 ++++++++++++++++++++--------- src/neo4j/_sync/work/session.py | 4 +- 
src/neo4j/_sync/work/workspace.py | 54 ++++++++------------- testkitbackend/test_config.json | 1 + tests/unit/async_/io/test_direct.py | 1 + tests/unit/mixed/io/test_direct.py | 14 +++--- tests/unit/sync/io/test_direct.py | 1 + 12 files changed, 165 insertions(+), 128 deletions(-) diff --git a/src/neo4j/_async/io/__init__.py b/src/neo4j/_async/io/__init__.py index 3571ad94..7f068da7 100644 --- a/src/neo4j/_async/io/__init__.py +++ b/src/neo4j/_async/io/__init__.py @@ -22,7 +22,8 @@ """ __all__ = [ - "AcquireAuth", + "AcquisitionAuth", + "AcquisitionDatabase", "AsyncBolt", "AsyncBoltPool", "AsyncNeo4jPool", @@ -37,7 +38,8 @@ ConnectionErrorHandler, ) from ._pool import ( - AcquireAuth, + AcquisitionAuth, + AcquisitionDatabase, AsyncBoltPool, AsyncNeo4jPool, ) diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index fbcae249..96a7f234 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -75,11 +75,17 @@ @dataclass -class AcquireAuth: +class AcquisitionAuth: auth: AsyncAuthManager | AuthManager | None force_auth: bool = False +@dataclass +class AcquisitionDatabase: + name: str | None + guessed: bool = False + + @dataclass class ConnectionFeatureTracker: feature_check: t.Callable[[AsyncBolt], bool] @@ -315,7 +321,7 @@ async def _acquire(self, address, auth, deadline, liveness_check_timeout): This method is thread safe. """ if auth is None: - auth = AcquireAuth(None) + auth = AcquisitionAuth(None) force_auth = auth.force_auth auth = auth.auth if liveness_check_timeout is None: @@ -410,8 +416,9 @@ async def acquire( timeout, database, bookmarks, - auth: AcquireAuth, + auth: AcquisitionAuth, liveness_check_timeout, + database_callback=None, ): """ Acquire a connection to a server that can satisfy a set of parameters. @@ -423,6 +430,7 @@ async def acquire( :param bookmarks: :param auth: :param liveness_check_timeout: + :param database_callback: """ ... @@ -640,8 +648,9 @@ async def acquire( timeout, database, bookmarks, - auth: AcquireAuth, + auth: AcquisitionAuth, liveness_check_timeout, + database_callback=None, ): # The access_mode and database is not needed for a direct connection, # it's just there for consistency. 
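[Editor's note — illustrative sketch, not part of the patch.] The hunks above rename `AcquireAuth` to `AcquisitionAuth` and introduce a second acquisition marker, `AcquisitionDatabase`. Its `guessed` flag distinguishes a database name the driver merely took from the home-db cache (possibly stale) from an authoritative one (explicitly configured, pinned by the session, or freshly resolved by the server). A minimal sketch of the intended call sites, using only what the diff itself defines:

from __future__ import annotations

from dataclasses import dataclass


@dataclass
class AcquisitionDatabase:
    name: str | None
    guessed: bool = False


# Explicitly configured database: authoritative, safe to send to the server.
pinned = AcquisitionDatabase("movies")
# Name taken from the home-db cache: it is only a guess and may be stale.
cached = AcquisitionDatabase("movies", guessed=True)
# Nothing known yet: let the server resolve the user's home database.
unknown = AcquisitionDatabase(None)

Carrying the flag alongside the name lets the pool and workspace layers decide per call whether the name may be pinned server-side or must be re-resolved.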
@@ -731,6 +740,10 @@ async def get_or_create_routing_table(self, database): ) return self.routing_tables[database] + async def get_routing_table(self, database): + async with self.refresh_lock: + return self.routing_tables.get(database) + async def fetch_routing_info( self, address, database, imp_user, bookmarks, auth, acquisition_timeout ): @@ -942,13 +955,16 @@ async def update_routing_table( :raise neo4j.exceptions.ServiceUnavailable: """ async with self.refresh_lock: - routing_table = await self.get_or_create_routing_table(database) - # copied because it can be modified - existing_routers = set(routing_table.routers) - - prefer_initial_routing_address = self.routing_tables[ - database - ].initialized_without_writers + routing_table = await self.get_routing_table(database) + if routing_table is not None: + # copied because it can be modified + existing_routers = set(routing_table.routers) + prefer_initial_routing_address = ( + routing_table.initialized_without_writers + ) + else: + existing_routers = {self.address} + prefer_initial_routing_address = True if ( prefer_initial_routing_address @@ -998,12 +1014,17 @@ async def update_routing_table( async def update_connection_pool(self, *, database): async with self.refresh_lock: - routing_tables = [await self.get_or_create_routing_table(database)] + rt = await self.get_routing_table(database) + routing_tables = [rt] if rt is not None else [] for db in self.routing_tables: if db == database: continue routing_tables.append(self.routing_tables[db]) - servers = set.union(*(rt.servers() for rt in routing_tables)) + + servers = set.union( + *(rt.servers() for rt in routing_tables), + self.address, + ) for address in list(self.connections): if address._unresolved not in servers: await super().deactivate(address) @@ -1012,13 +1033,13 @@ async def ensure_routing_table_is_fresh( self, *, access_mode, - database, + database: AcquisitionDatabase, imp_user, bookmarks, auth=None, acquisition_timeout=None, database_callback=None, - ): + ) -> bool: """ Update the routing table if stale. 
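[Editor's note — rough sketch, not part of the patch.] The reworked `update_connection_pool` above prunes connections to servers that no longer appear in any routing table, while keeping the initial router reachable even before the first routing table for the target database exists. A condensed illustration of that pruning logic, under the assumption that the initial address is meant to be kept as a single element of the union (hence wrapped in a set here):

def prune_candidates(routing_tables, initial_address, connected_addresses):
    # Union of all servers still referenced by some routing table; the
    # initial router address is always kept so the driver can re-route.
    servers = set.union(
        {initial_address},
        *(rt.servers() for rt in routing_tables),
    )
    # Addresses with no remaining role in the cluster may be deactivated.
    return [addr for addr in connected_addresses if addr not in servers]

This also covers the new edge case where `get_routing_table` returns `None` and the list of routing tables is empty: the union still yields the initial router instead of failing or pruning everything.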
@@ -1050,8 +1071,10 @@ async def ensure_routing_table_is_fresh( ) del self.routing_tables[database_] - routing_table = await self.get_or_create_routing_table(database) - if routing_table.is_fresh(readonly=(access_mode == READ_ACCESS)): + routing_table = await self.get_routing_table(database.name) + if routing_table is not None and routing_table.is_fresh( + readonly=(access_mode == READ_ACCESS) + ): # table is still valid log.debug( "[#0000] _: using existing routing table %r", @@ -1059,15 +1082,18 @@ async def ensure_routing_table_is_fresh( ) return False + async def wrapped_database_callback(database: str | None) -> None: + await AsyncUtil.callback(database_callback, database) + await self.update_connection_pool(database=database) + await self.update_routing_table( - database=database, + database=database.name if not database.guessed else None, imp_user=imp_user, bookmarks=bookmarks, auth=auth, acquisition_timeout=acquisition_timeout, - database_callback=database_callback, + database_callback=wrapped_database_callback, ) - await self.update_connection_pool(database=database) return True @@ -1104,10 +1130,11 @@ async def acquire( self, access_mode, timeout, - database, + database: AcquisitionDatabase, bookmarks, - auth: AcquireAuth | None, + auth: AcquisitionAuth | None, liveness_check_timeout, + database_callback=None, ): if access_mode not in {WRITE_ACCESS, READ_ACCESS}: # TODO: 6.0 - change this to be a ValueError @@ -1139,6 +1166,7 @@ async def acquire( bookmarks=bookmarks, auth=auth, acquisition_timeout=timeout, + database_callback=database_callback, ) while True: @@ -1146,7 +1174,7 @@ async def acquire( # Get an address for a connection that have the fewest in-use # connections. address = await self._select_address( - access_mode=access_mode, database=database + access_mode=access_mode, database=database.name ) except (ReadServiceUnavailable, WriteServiceUnavailable) as err: raise SessionExpired( diff --git a/src/neo4j/_async/work/session.py b/src/neo4j/_async/work/session.py index 264b31f3..dd7324ab 100644 --- a/src/neo4j/_async/work/session.py +++ b/src/neo4j/_async/work/session.py @@ -321,7 +321,7 @@ async def run( self._config.warn_notification_severity, self._result_closed, self._result_error, - self._make_query_database_resolution_callback(), + self._make_db_resolution_callback(), ) bookmarks = await self._get_bookmarks() parameters = dict(parameters or {}, **kwargs) @@ -449,7 +449,7 @@ async def _open_transaction( self._transaction_closed_handler, self._transaction_error_handler, self._transaction_cancel_handler, - self._make_query_database_resolution_callback(), + self._make_db_resolution_callback(), ) bookmarks = await self._get_bookmarks() await self._transaction._begin( diff --git a/src/neo4j/_async/work/workspace.py b/src/neo4j/_async/work/workspace.py index e76cfe4f..e7859107 100644 --- a/src/neo4j/_async/work/workspace.py +++ b/src/neo4j/_async/work/workspace.py @@ -18,7 +18,6 @@ import logging import typing as t -from dataclasses import dataclass from ..._async_compat.util import AsyncUtil from ..._auth_management import to_auth_dict @@ -34,7 +33,10 @@ SessionExpired, ) from .._debug import AsyncNonConcurrentMethodChecker -from ..io import AcquireAuth +from ..io import ( + AcquisitionAuth, + AcquisitionDatabase, +) if t.TYPE_CHECKING: @@ -54,12 +56,6 @@ log = logging.getLogger("neo4j") -@dataclass -class _TargetDatabase: - database: str | None - from_cache: bool = False - - class AsyncWorkspace(AsyncNonConcurrentMethodChecker): def __init__(self, pool, config): assert 
isinstance(config, WorkspaceConfig) @@ -108,21 +104,10 @@ async def __aenter__(self) -> AsyncWorkspace: async def __aexit__(self, exc_type, exc_value, traceback): await self.close() - def _make_routing_database_callback( - self, - cache_key: TKey, - ) -> t.Callable[[str], None]: - def _database_callback(database: str | None) -> None: - if not self._pinned_database: - self._set_pinned_database(database) - db_cache: AsyncHomeDbCache = self._pool.home_db_cache - db_cache.set(cache_key, database) - - return _database_callback + def _make_db_resolution_callback(self) -> t.Callable[[str], None] | None: + if self._pinned_database: + return None - def _make_query_database_resolution_callback( - self, - ) -> t.Callable[[str], None] | None: def _database_callback(database: str | None) -> None: if not self._pinned_database: self._set_pinned_database(database) @@ -187,7 +172,7 @@ async def _update_bookmark(self, bookmark): async def _connect(self, access_mode, auth=None, **acquire_kwargs) -> None: acquisition_timeout = self._config.connection_acquisition_timeout force_auth = acquire_kwargs.pop("force_auth", False) - acquire_auth = AcquireAuth(auth, force_auth=force_auth) + acquire_auth = AcquisitionAuth(auth, force_auth=force_auth) if self._connection: # TODO: Investigate this @@ -197,20 +182,21 @@ async def _connect(self, access_mode, auth=None, **acquire_kwargs) -> None: await self._disconnect() ssr_enabled = self._pool.ssr_enabled - routing_target = await self._get_routing_target_database( + target_db = await self._get_routing_target_database( acquire_auth, ssr_enabled=ssr_enabled ) acquire_kwargs_ = { "access_mode": access_mode, "timeout": acquisition_timeout, - "database": routing_target.database, + "database": target_db, "bookmarks": await self._get_bookmarks(), "auth": acquire_auth, "liveness_check_timeout": None, + "database_callback": self._make_db_resolution_callback(), } acquire_kwargs_.update(acquire_kwargs) self._connection = await self._pool.acquire(**acquire_kwargs_) - if routing_target.from_cache and ( + if target_db.guessed and ( not self._pool.ssr_enabled or not self._connection.ssr_enabled ): # race condition: in the meantime, the pool added a connection, @@ -221,25 +207,25 @@ async def _connect(self, access_mode, auth=None, **acquire_kwargs) -> None: "falling back to explicit home database resolution", ) await self._disconnect() - routing_target = await self._get_routing_target_database( + target_db = await self._get_routing_target_database( acquire_auth, ssr_enabled=False ) - acquire_kwargs_["database"] = routing_target.database + acquire_kwargs_["database"] = target_db self._connection = await self._pool.acquire(**acquire_kwargs_) self._connection_access_mode = access_mode async def _get_routing_target_database( self, - acquire_auth: AcquireAuth, + acquire_auth: AcquisitionAuth, ssr_enabled: bool, - ) -> _TargetDatabase: + ) -> AcquisitionDatabase: if self._config.database is not None or self._pool.is_direct_pool: self._set_pinned_database(self._config.database) log.debug( "[#0000] _: routing towards fixed database: %s", self._config.database, ) - return _TargetDatabase(self._config.database) + return AcquisitionDatabase(self._config.database) auth = acquire_auth.auth resolved_auth = await self._resolve_session_auth(auth) @@ -260,7 +246,7 @@ async def _get_routing_target_database( ), cached_db, ) - return _TargetDatabase(cached_db, from_cache=True) + return AcquisitionDatabase(cached_db, guessed=True) acquisition_timeout = self._config.connection_acquisition_timeout 
log.debug("[#0000] _: resolve home database") @@ -270,9 +256,9 @@ async def _get_routing_target_database( bookmarks=await self._get_bookmarks(), auth=acquire_auth, acquisition_timeout=acquisition_timeout, - database_callback=self._make_routing_database_callback(cache_key), + database_callback=self._make_db_resolution_callback(), ) - return _TargetDatabase(self._config.database) + return AcquisitionDatabase(self._config.database) @staticmethod async def _resolve_session_auth( diff --git a/src/neo4j/_sync/io/__init__.py b/src/neo4j/_sync/io/__init__.py index a1833c74..5a7ea831 100644 --- a/src/neo4j/_sync/io/__init__.py +++ b/src/neo4j/_sync/io/__init__.py @@ -22,7 +22,8 @@ """ __all__ = [ - "AcquireAuth", + "AcquisitionAuth", + "AcquisitionDatabase", "Bolt", "BoltPool", "Neo4jPool", @@ -37,7 +38,8 @@ ConnectionErrorHandler, ) from ._pool import ( - AcquireAuth, + AcquisitionAuth, + AcquisitionDatabase, BoltPool, Neo4jPool, ) diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py index e5b406cf..8dbbc164 100644 --- a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -72,11 +72,17 @@ @dataclass -class AcquireAuth: +class AcquisitionAuth: auth: AuthManager | AuthManager | None force_auth: bool = False +@dataclass +class AcquisitionDatabase: + name: str | None + guessed: bool = False + + @dataclass class ConnectionFeatureTracker: feature_check: t.Callable[[Bolt], bool] @@ -312,7 +318,7 @@ def _acquire(self, address, auth, deadline, liveness_check_timeout): This method is thread safe. """ if auth is None: - auth = AcquireAuth(None) + auth = AcquisitionAuth(None) force_auth = auth.force_auth auth = auth.auth if liveness_check_timeout is None: @@ -407,8 +413,9 @@ def acquire( timeout, database, bookmarks, - auth: AcquireAuth, + auth: AcquisitionAuth, liveness_check_timeout, + database_callback=None, ): """ Acquire a connection to a server that can satisfy a set of parameters. @@ -420,6 +427,7 @@ def acquire( :param bookmarks: :param auth: :param liveness_check_timeout: + :param database_callback: """ ... @@ -637,8 +645,9 @@ def acquire( timeout, database, bookmarks, - auth: AcquireAuth, + auth: AcquisitionAuth, liveness_check_timeout, + database_callback=None, ): # The access_mode and database is not needed for a direct connection, # it's just there for consistency. 
@@ -728,6 +737,10 @@ def get_or_create_routing_table(self, database): ) return self.routing_tables[database] + def get_routing_table(self, database): + with self.refresh_lock: + return self.routing_tables.get(database) + def fetch_routing_info( self, address, database, imp_user, bookmarks, auth, acquisition_timeout ): @@ -939,13 +952,16 @@ def update_routing_table( :raise neo4j.exceptions.ServiceUnavailable: """ with self.refresh_lock: - routing_table = self.get_or_create_routing_table(database) - # copied because it can be modified - existing_routers = set(routing_table.routers) - - prefer_initial_routing_address = self.routing_tables[ - database - ].initialized_without_writers + routing_table = self.get_routing_table(database) + if routing_table is not None: + # copied because it can be modified + existing_routers = set(routing_table.routers) + prefer_initial_routing_address = ( + routing_table.initialized_without_writers + ) + else: + existing_routers = {self.address} + prefer_initial_routing_address = True if ( prefer_initial_routing_address @@ -995,12 +1011,17 @@ def update_routing_table( def update_connection_pool(self, *, database): with self.refresh_lock: - routing_tables = [self.get_or_create_routing_table(database)] + rt = self.get_routing_table(database) + routing_tables = [rt] if rt is not None else [] for db in self.routing_tables: if db == database: continue routing_tables.append(self.routing_tables[db]) - servers = set.union(*(rt.servers() for rt in routing_tables)) + + servers = set.union( + *(rt.servers() for rt in routing_tables), + self.address, + ) for address in list(self.connections): if address._unresolved not in servers: super().deactivate(address) @@ -1009,13 +1030,13 @@ def ensure_routing_table_is_fresh( self, *, access_mode, - database, + database: AcquisitionDatabase, imp_user, bookmarks, auth=None, acquisition_timeout=None, database_callback=None, - ): + ) -> bool: """ Update the routing table if stale. 
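[Editor's note — illustrative sketch, not part of the patch.] The hunk below restates the core of "Don't send db name to server when taken from cache" for the sync path: a name that was only guessed from the home-db cache is withheld from the ROUTE request (`database.name if not database.guessed else None`), so the server resolves the user's home database itself and the wrapped callback can report the authoritative name back. Condensed:

from __future__ import annotations

from dataclasses import dataclass


@dataclass
class AcquisitionDatabase:  # as introduced earlier in this patch
    name: str | None
    guessed: bool = False


def route_request_database(database: AcquisitionDatabase) -> str | None:
    # A guessed (cache-derived) name must not be sent to the server;
    # None makes the server resolve the user's home database instead.
    return database.name if not database.guessed else None


assert route_request_database(AcquisitionDatabase("movies")) == "movies"
assert route_request_database(AcquisitionDatabase("movies", guessed=True)) is None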
@@ -1047,8 +1068,10 @@ def ensure_routing_table_is_fresh( ) del self.routing_tables[database_] - routing_table = self.get_or_create_routing_table(database) - if routing_table.is_fresh(readonly=(access_mode == READ_ACCESS)): + routing_table = self.get_routing_table(database.name) + if routing_table is not None and routing_table.is_fresh( + readonly=(access_mode == READ_ACCESS) + ): # table is still valid log.debug( "[#0000] _: using existing routing table %r", @@ -1056,15 +1079,18 @@ def ensure_routing_table_is_fresh( ) return False + def wrapped_database_callback(database: str | None) -> None: + Util.callback(database_callback, database) + self.update_connection_pool(database=database) + self.update_routing_table( - database=database, + database=database.name if not database.guessed else None, imp_user=imp_user, bookmarks=bookmarks, auth=auth, acquisition_timeout=acquisition_timeout, - database_callback=database_callback, + database_callback=wrapped_database_callback, ) - self.update_connection_pool(database=database) return True @@ -1101,10 +1127,11 @@ def acquire( self, access_mode, timeout, - database, + database: AcquisitionDatabase, bookmarks, - auth: AcquireAuth | None, + auth: AcquisitionAuth | None, liveness_check_timeout, + database_callback=None, ): if access_mode not in {WRITE_ACCESS, READ_ACCESS}: # TODO: 6.0 - change this to be a ValueError @@ -1136,6 +1163,7 @@ def acquire( bookmarks=bookmarks, auth=auth, acquisition_timeout=timeout, + database_callback=database_callback, ) while True: @@ -1143,7 +1171,7 @@ def acquire( # Get an address for a connection that have the fewest in-use # connections. address = self._select_address( - access_mode=access_mode, database=database + access_mode=access_mode, database=database.name ) except (ReadServiceUnavailable, WriteServiceUnavailable) as err: raise SessionExpired( diff --git a/src/neo4j/_sync/work/session.py b/src/neo4j/_sync/work/session.py index 99f77a4f..910fe328 100644 --- a/src/neo4j/_sync/work/session.py +++ b/src/neo4j/_sync/work/session.py @@ -321,7 +321,7 @@ def run( self._config.warn_notification_severity, self._result_closed, self._result_error, - self._make_query_database_resolution_callback(), + self._make_db_resolution_callback(), ) bookmarks = self._get_bookmarks() parameters = dict(parameters or {}, **kwargs) @@ -449,7 +449,7 @@ def _open_transaction( self._transaction_closed_handler, self._transaction_error_handler, self._transaction_cancel_handler, - self._make_query_database_resolution_callback(), + self._make_db_resolution_callback(), ) bookmarks = self._get_bookmarks() self._transaction._begin( diff --git a/src/neo4j/_sync/work/workspace.py b/src/neo4j/_sync/work/workspace.py index d84938fb..8ada7cc5 100644 --- a/src/neo4j/_sync/work/workspace.py +++ b/src/neo4j/_sync/work/workspace.py @@ -18,7 +18,6 @@ import logging import typing as t -from dataclasses import dataclass from ..._async_compat.util import Util from ..._auth_management import to_auth_dict @@ -34,7 +33,10 @@ SessionExpired, ) from .._debug import NonConcurrentMethodChecker -from ..io import AcquireAuth +from ..io import ( + AcquisitionAuth, + AcquisitionDatabase, +) if t.TYPE_CHECKING: @@ -51,12 +53,6 @@ log = logging.getLogger("neo4j") -@dataclass -class _TargetDatabase: - database: str | None - from_cache: bool = False - - class Workspace(NonConcurrentMethodChecker): def __init__(self, pool, config): assert isinstance(config, WorkspaceConfig) @@ -105,21 +101,10 @@ def __enter__(self) -> Workspace: def __exit__(self, exc_type, exc_value, 
traceback): self.close() - def _make_routing_database_callback( - self, - cache_key: TKey, - ) -> t.Callable[[str], None]: - def _database_callback(database: str | None) -> None: - if not self._pinned_database: - self._set_pinned_database(database) - db_cache: HomeDbCache = self._pool.home_db_cache - db_cache.set(cache_key, database) - - return _database_callback + def _make_db_resolution_callback(self) -> t.Callable[[str], None] | None: + if self._pinned_database: + return None - def _make_query_database_resolution_callback( - self, - ) -> t.Callable[[str], None] | None: def _database_callback(database: str | None) -> None: if not self._pinned_database: self._set_pinned_database(database) @@ -184,7 +169,7 @@ def _update_bookmark(self, bookmark): def _connect(self, access_mode, auth=None, **acquire_kwargs) -> None: acquisition_timeout = self._config.connection_acquisition_timeout force_auth = acquire_kwargs.pop("force_auth", False) - acquire_auth = AcquireAuth(auth, force_auth=force_auth) + acquire_auth = AcquisitionAuth(auth, force_auth=force_auth) if self._connection: # TODO: Investigate this @@ -194,20 +179,21 @@ def _connect(self, access_mode, auth=None, **acquire_kwargs) -> None: self._disconnect() ssr_enabled = self._pool.ssr_enabled - routing_target = self._get_routing_target_database( + target_db = self._get_routing_target_database( acquire_auth, ssr_enabled=ssr_enabled ) acquire_kwargs_ = { "access_mode": access_mode, "timeout": acquisition_timeout, - "database": routing_target.database, + "database": target_db, "bookmarks": self._get_bookmarks(), "auth": acquire_auth, "liveness_check_timeout": None, + "database_callback": self._make_db_resolution_callback(), } acquire_kwargs_.update(acquire_kwargs) self._connection = self._pool.acquire(**acquire_kwargs_) - if routing_target.from_cache and ( + if target_db.guessed and ( not self._pool.ssr_enabled or not self._connection.ssr_enabled ): # race condition: in the meantime, the pool added a connection, @@ -218,25 +204,25 @@ def _connect(self, access_mode, auth=None, **acquire_kwargs) -> None: "falling back to explicit home database resolution", ) self._disconnect() - routing_target = self._get_routing_target_database( + target_db = self._get_routing_target_database( acquire_auth, ssr_enabled=False ) - acquire_kwargs_["database"] = routing_target.database + acquire_kwargs_["database"] = target_db self._connection = self._pool.acquire(**acquire_kwargs_) self._connection_access_mode = access_mode def _get_routing_target_database( self, - acquire_auth: AcquireAuth, + acquire_auth: AcquisitionAuth, ssr_enabled: bool, - ) -> _TargetDatabase: + ) -> AcquisitionDatabase: if self._config.database is not None or self._pool.is_direct_pool: self._set_pinned_database(self._config.database) log.debug( "[#0000] _: routing towards fixed database: %s", self._config.database, ) - return _TargetDatabase(self._config.database) + return AcquisitionDatabase(self._config.database) auth = acquire_auth.auth resolved_auth = self._resolve_session_auth(auth) @@ -257,7 +243,7 @@ def _get_routing_target_database( ), cached_db, ) - return _TargetDatabase(cached_db, from_cache=True) + return AcquisitionDatabase(cached_db, guessed=True) acquisition_timeout = self._config.connection_acquisition_timeout log.debug("[#0000] _: resolve home database") @@ -267,9 +253,9 @@ def _get_routing_target_database( bookmarks=self._get_bookmarks(), auth=acquire_auth, acquisition_timeout=acquisition_timeout, - database_callback=self._make_routing_database_callback(cache_key), + 
database_callback=self._make_db_resolution_callback(), ) - return _TargetDatabase(self._config.database) + return AcquisitionDatabase(self._config.database) @staticmethod def _resolve_session_auth( diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index bca7f0ca..cbee2bf0 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -59,6 +59,7 @@ "Feature:Bolt:5.5": true, "Feature:Bolt:5.6": true, "Feature:Bolt:5.7": true, + "Feature:Bolt:5.8": true, "Feature:Bolt:Patch:UTC": true, "Feature:Impersonation": true, "Feature:TLS:1.1": "Driver blocks TLS 1.1 for security reasons.", diff --git a/tests/unit/async_/io/test_direct.py b/tests/unit/async_/io/test_direct.py index 80014266..c6add31f 100644 --- a/tests/unit/async_/io/test_direct.py +++ b/tests/unit/async_/io/test_direct.py @@ -67,6 +67,7 @@ async def acquire( bookmarks, auth, liveness_check_timeout, + database_callback=None, ): return await self._acquire( self.address, auth, timeout, liveness_check_timeout diff --git a/tests/unit/mixed/io/test_direct.py b/tests/unit/mixed/io/test_direct.py index f330fdfb..d943fccf 100644 --- a/tests/unit/mixed/io/test_direct.py +++ b/tests/unit/mixed/io/test_direct.py @@ -29,9 +29,9 @@ import pytest -from neo4j._async.io._pool import AcquireAuth as AsyncAcquireAuth +from neo4j._async.io._pool import AcquisitionAuth as AsyncAcquisitionAuth from neo4j._deadline import Deadline -from neo4j._sync.io._pool import AcquireAuth +from neo4j._sync.io._pool import AcquisitionAuth from ...async_.io.test_direct import AsyncFakeBoltPool from ...async_.test_auth_management import ( @@ -128,11 +128,11 @@ def acquire_release_conn( def test_full_pool_re_auth(self, fake_connection_generator, mocker): address = ("127.0.0.1", 7687) - acquire_auth1 = AcquireAuth( + acquire_auth1 = AcquisitionAuth( auth=static_auth_manager(("user1", "pass1")) ) auth2 = ("user2", "pass2") - acquire_auth2 = AcquireAuth(auth=static_auth_manager(auth2)) + acquire_auth2 = AcquisitionAuth(auth=static_auth_manager(auth2)) acquire1_event = threading.Event() cx1 = None @@ -243,11 +243,13 @@ async def test_full_pool_re_auth_async( self, async_fake_connection_generator, mocker ): address = ("127.0.0.1", 7687) - acquire_auth1 = AsyncAcquireAuth( + acquire_auth1 = AsyncAcquisitionAuth( auth=static_async_auth_manager(("user1", "pass1")) ) auth2 = ("user2", "pass2") - acquire_auth2 = AsyncAcquireAuth(auth=static_async_auth_manager(auth2)) + acquire_auth2 = AsyncAcquisitionAuth( + auth=static_async_auth_manager(auth2) + ) cx1 = None async def acquire1(pool_): diff --git a/tests/unit/sync/io/test_direct.py b/tests/unit/sync/io/test_direct.py index a899ae49..64b7d9b5 100644 --- a/tests/unit/sync/io/test_direct.py +++ b/tests/unit/sync/io/test_direct.py @@ -67,6 +67,7 @@ def acquire( bookmarks, auth, liveness_check_timeout, + database_callback=None, ): return self._acquire( self.address, auth, timeout, liveness_check_timeout From f0efc47175d251d21504bfd32784f885a4d95fb0 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Tue, 12 Nov 2024 13:32:14 +0100 Subject: [PATCH 07/26] Optimization: cache basic auth by principal to match impersonation --- src/neo4j/_async/home_db_cache.py | 10 +++++++++- src/neo4j/_sync/home_db_cache.py | 10 +++++++++- testkitbackend/test_config.json | 1 + 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/src/neo4j/_async/home_db_cache.py b/src/neo4j/_async/home_db_cache.py index 63b6ef16..9ddbca22 100644 --- a/src/neo4j/_async/home_db_cache.py +++ 
b/src/neo4j/_async/home_db_cache.py @@ -62,7 +62,7 @@ def compute_key( if imp_user is not None: return imp_user if auth is not None: - return _hashable_dict(auth) + return _consolidate_auth_token(auth) return (None,) def get(self, key: TKey) -> str | None: @@ -117,6 +117,14 @@ def enabled(self) -> bool: return self._enabled +def _consolidate_auth_token(auth: dict) -> tuple | str: + if auth.get("scheme") == "basic" and isinstance( + auth.get("principal"), str + ): + return auth["principal"] + return _hashable_dict(auth) + + def _hashable_dict(d: dict) -> tuple: return tuple( (k, _hashable_dict(v) if isinstance(v, dict) else v) diff --git a/src/neo4j/_sync/home_db_cache.py b/src/neo4j/_sync/home_db_cache.py index e56d052f..537642af 100644 --- a/src/neo4j/_sync/home_db_cache.py +++ b/src/neo4j/_sync/home_db_cache.py @@ -62,7 +62,7 @@ def compute_key( if imp_user is not None: return imp_user if auth is not None: - return _hashable_dict(auth) + return _consolidate_auth_token(auth) return (None,) def get(self, key: TKey) -> str | None: @@ -117,6 +117,14 @@ def enabled(self) -> bool: return self._enabled +def _consolidate_auth_token(auth: dict) -> tuple | str: + if auth.get("scheme") == "basic" and isinstance( + auth.get("principal"), str + ): + return auth["principal"] + return _hashable_dict(auth) + + def _hashable_dict(d: dict) -> tuple: return tuple( (k, _hashable_dict(v) if isinstance(v, dict) else v) diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index cbee2bf0..38d2b38c 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -71,6 +71,7 @@ "Optimization:ConnectionReuse": true, "Optimization:EagerTransactionBegin": true, "Optimization:ExecuteQueryPipelining": true, + "Optimization:HomeDbCacheBasicPrincipalIsImpersonatedUser": true, "Optimization:ImplicitDefaultArguments": true, "Optimization:MinimalBookmarksSet": true, "Optimization:MinimalResets": true, From 9befd196020bd80ed7edcc3b31cdb4069d7849a0 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Wed, 13 Nov 2024 09:39:59 +0100 Subject: [PATCH 08/26] Fix driver finding new home db on ROUTE after cache hit --- src/neo4j/_async/io/_pool.py | 22 +++++++++++++++------- src/neo4j/_sync/io/_pool.py | 22 +++++++++++++++------- 2 files changed, 30 insertions(+), 14 deletions(-) diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index 96a7f234..45ec6371 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -1148,10 +1148,14 @@ async def acquire( from ...api import check_access_mode access_mode = check_access_mode(access_mode) - # await self.ensure_routing_table_is_fresh( - # access_mode=access_mode, database=database, imp_user=None, - # bookmarks=bookmarks, acquisition_timeout=timeout - # ) + + target_database = database.name + + async def wrapped_database_callback(new_database): + nonlocal target_database + if new_database is not None: + target_database = new_database + await AsyncUtil.callback(database_callback, new_database) log.debug( "[#0000] _: acquire routing connection, " @@ -1166,7 +1170,11 @@ async def acquire( bookmarks=bookmarks, auth=auth, acquisition_timeout=timeout, - database_callback=database_callback, + database_callback=( + wrapped_database_callback + if database.guessed + else database_callback + ), ) while True: @@ -1174,7 +1182,7 @@ async def acquire( # Get an address for a connection that have the fewest in-use # connections. 
address = await self._select_address( - access_mode=access_mode, database=database.name + access_mode=access_mode, database=target_database ) except (ReadServiceUnavailable, WriteServiceUnavailable) as err: raise SessionExpired( @@ -1185,7 +1193,7 @@ async def acquire( log.debug( "[#0000] _: acquire address, database=%r " "address=%r", - database, + target_database, address, ) deadline = Deadline.from_timeout_or_deadline(timeout) diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py index 8dbbc164..c8995725 100644 --- a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -1145,10 +1145,14 @@ def acquire( from ...api import check_access_mode access_mode = check_access_mode(access_mode) - # await self.ensure_routing_table_is_fresh( - # access_mode=access_mode, database=database, imp_user=None, - # bookmarks=bookmarks, acquisition_timeout=timeout - # ) + + target_database = database.name + + def wrapped_database_callback(new_database): + nonlocal target_database + if new_database is not None: + target_database = new_database + Util.callback(database_callback, new_database) log.debug( "[#0000] _: acquire routing connection, " @@ -1163,7 +1167,11 @@ def acquire( bookmarks=bookmarks, auth=auth, acquisition_timeout=timeout, - database_callback=database_callback, + database_callback=( + wrapped_database_callback + if database.guessed + else database_callback + ), ) while True: @@ -1171,7 +1179,7 @@ def acquire( # Get an address for a connection that have the fewest in-use # connections. address = self._select_address( - access_mode=access_mode, database=database.name + access_mode=access_mode, database=target_database ) except (ReadServiceUnavailable, WriteServiceUnavailable) as err: raise SessionExpired( @@ -1182,7 +1190,7 @@ def acquire( log.debug( "[#0000] _: acquire address, database=%r " "address=%r", - database, + target_database, address, ) deadline = Deadline.from_timeout_or_deadline(timeout) From de391aaec75663b07962e17df5b9fa6922dfe766 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Wed, 13 Nov 2024 16:32:07 +0100 Subject: [PATCH 09/26] Unit tests for home db cache --- src/neo4j/_async/home_db_cache.py | 41 ++-- src/neo4j/_sync/home_db_cache.py | 41 ++-- tests/unit/async_/test_home_db_cache.py | 239 ++++++++++++++++++++++++ tests/unit/mixed/test_home_db_cache.py | 88 +++++++++ tests/unit/sync/test_home_db_cache.py | 239 ++++++++++++++++++++++++ 5 files changed, 614 insertions(+), 34 deletions(-) create mode 100644 tests/unit/async_/test_home_db_cache.py create mode 100644 tests/unit/mixed/test_home_db_cache.py create mode 100644 tests/unit/sync/test_home_db_cache.py diff --git a/src/neo4j/_async/home_db_cache.py b/src/neo4j/_async/home_db_cache.py index 9ddbca22..bcef66e4 100644 --- a/src/neo4j/_async/home_db_cache.py +++ b/src/neo4j/_async/home_db_cache.py @@ -49,7 +49,12 @@ def __init__( self._ttl = ttl self._cache: dict[TKey, TVal] = {} self._lock = AsyncCooperativeLock() - self._last_clean = monotonic() + self._oldest_entry = monotonic() + if max_size is not None and max_size <= 0: + raise ValueError( + f"home db cache max_size must be greater 0 or None, " + f"got {max_size}" + ) self._max_size = max_size def compute_key( @@ -69,36 +74,40 @@ def get(self, key: TKey) -> str | None: if not self._enabled: return None with self._lock: + self._clean(monotonic()) val = self._cache.get(key) if val is None: return None - now = monotonic() - if now - val[0] > self._ttl: - del self._cache[key] - return None - # Saved some time with a cache hit, - # so we can waste 
some with cleaning the cache ;) - self._clean(now) return val[1] def set(self, key: TKey, value: str | None) -> None: if not self._enabled: return with self._lock: + now = monotonic() + self._clean(now) if value is None: self._cache.pop(key, None) else: - self._cache[key] = (monotonic(), value) + self._cache[key] = (now, value) def clear(self) -> None: if not self._enabled: return with self._lock: self._cache = {} - self._last_clean = monotonic() + self._oldest_entry = monotonic() - def _clean(self, now: float) -> None: - if self._max_size is not None and len(self._cache) > self._max_size: + def _clean(self, now: float | None = None) -> None: + now = monotonic() if now is None else now + if now - self._oldest_entry > self._ttl: + self._cache = { + k: v for k, v in self._cache.items() if now - v[0] < self._ttl + } + self._oldest_entry = min( + (v[0] for v in self._cache.values()), default=now + ) + if self._max_size and len(self._cache) > self._max_size: self._cache = dict( heapq.nlargest( self._max_size, @@ -106,11 +115,9 @@ def _clean(self, now: float) -> None: key=lambda item: item[1][0], ) ) - if now - self._last_clean > self._ttl: - self._cache = { - k: v for k, v in self._cache.items() if now - v[0] < self._ttl - } - self._last_clean = now + + def __len__(self) -> int: + return len(self._cache) @property def enabled(self) -> bool: diff --git a/src/neo4j/_sync/home_db_cache.py b/src/neo4j/_sync/home_db_cache.py index 537642af..3e6167d4 100644 --- a/src/neo4j/_sync/home_db_cache.py +++ b/src/neo4j/_sync/home_db_cache.py @@ -49,7 +49,12 @@ def __init__( self._ttl = ttl self._cache: dict[TKey, TVal] = {} self._lock = CooperativeLock() - self._last_clean = monotonic() + self._oldest_entry = monotonic() + if max_size is not None and max_size <= 0: + raise ValueError( + f"home db cache max_size must be greater 0 or None, " + f"got {max_size}" + ) self._max_size = max_size def compute_key( @@ -69,36 +74,40 @@ def get(self, key: TKey) -> str | None: if not self._enabled: return None with self._lock: + self._clean(monotonic()) val = self._cache.get(key) if val is None: return None - now = monotonic() - if now - val[0] > self._ttl: - del self._cache[key] - return None - # Saved some time with a cache hit, - # so we can waste some with cleaning the cache ;) - self._clean(now) return val[1] def set(self, key: TKey, value: str | None) -> None: if not self._enabled: return with self._lock: + now = monotonic() + self._clean(now) if value is None: self._cache.pop(key, None) else: - self._cache[key] = (monotonic(), value) + self._cache[key] = (now, value) def clear(self) -> None: if not self._enabled: return with self._lock: self._cache = {} - self._last_clean = monotonic() + self._oldest_entry = monotonic() - def _clean(self, now: float) -> None: - if self._max_size is not None and len(self._cache) > self._max_size: + def _clean(self, now: float | None = None) -> None: + now = monotonic() if now is None else now + if now - self._oldest_entry > self._ttl: + self._cache = { + k: v for k, v in self._cache.items() if now - v[0] < self._ttl + } + self._oldest_entry = min( + (v[0] for v in self._cache.values()), default=now + ) + if self._max_size and len(self._cache) > self._max_size: self._cache = dict( heapq.nlargest( self._max_size, @@ -106,11 +115,9 @@ def _clean(self, now: float) -> None: key=lambda item: item[1][0], ) ) - if now - self._last_clean > self._ttl: - self._cache = { - k: v for k, v in self._cache.items() if now - v[0] < self._ttl - } - self._last_clean = now + + def __len__(self) -> int: 
+ return len(self._cache) @property def enabled(self) -> bool: diff --git a/tests/unit/async_/test_home_db_cache.py b/tests/unit/async_/test_home_db_cache.py new file mode 100644 index 00000000..9b5f33f3 --- /dev/null +++ b/tests/unit/async_/test_home_db_cache.py @@ -0,0 +1,239 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +import typing as t +from datetime import ( + datetime, + timedelta, +) + +import freezegun +import pytest +import pytz + +from neo4j._async.home_db_cache import AsyncHomeDbCache +from neo4j.time import DateTime + + +if t.TYPE_CHECKING: + from neo4j._async.home_db_cache import TKey + + +@pytest.mark.parametrize("enabled", (True, False)) +def test_key_none_is_none(enabled: bool) -> None: + assert AsyncHomeDbCache(enabled=enabled).compute_key(None, None) == (None,) + + +@pytest.mark.parametrize( + "auth", + ( + None, + {}, + {"scheme": "basic", "principal": "neo4j", "credentials": "password"}, + {"scheme": "custom", "principal": "neo4j", "credentials": "password"}, + {"scheme": "custom", "credentials": "nice token"}, + {"foo": "bar"}, + ), +) +@pytest.mark.parametrize("enabled", (True, False)) +def test_key_imp_precedence_over_auth( + auth: dict | None, + enabled: bool, +) -> None: + cache = AsyncHomeDbCache(enabled=enabled) + assert cache.compute_key("bob", auth) == ("bob" if enabled else (None,)) + + +@pytest.mark.parametrize( + "auth", + ( + {}, + {"scheme": "basic", "principal": "neo4j", "credentials": "password"}, + {"scheme": "basic", "principal": "this is wrong, no password?"}, + {"scheme": "basic", "credentials": "this is wrong, no user?"}, + {"scheme": "none"}, + {"scheme": "none", "principal": "even though the scheme is none"}, + {"scheme": "kerberos", "principal": "", "credentials": "ticket"}, + {"scheme": "bearer", "credentials": "nice SSO token"}, + {"scheme": "custom", "principal": "neo4j", "credentials": "password"}, + {"scheme": "custom", "credentials": "bar", "parameters": {"oh": "hi"}}, + {"foo": "bar"}, + ), +) +def test_key_reduces_basic_auth_to_principal(auth: dict) -> None: + key = AsyncHomeDbCache().compute_key(None, auth) + if auth.get("scheme") == "basic" and "principal" in auth: + assert isinstance(key, str) + assert key == auth["principal"] + else: + assert isinstance(key, tuple) + for e in key: + assert isinstance(e, tuple) and len(e) == 2 + assert isinstance(e[0], str) + + +_NAN = float("nan") +_NOW = pytz.timezone("Europe/Stockholm").localize( + DateTime(2021, 8, 12, 12, 34, 57, 123456789) +) + + +@pytest.mark.parametrize( + ("auth1", "auth2"), + ( + ({}, {}), + ({"foo": "bar"}, {"foo": "bar"}), + ({"a": 1, "b": 2}, {"b": 2, "a": 1}), + ( + { + "scheme": "funky", + "credentials": "t0pS3cr3t!!11", + "parameters": { + "how much": 1.5, + # Note: for special values (NaN, temporal types, etc.), + # equality may rely on object identity. 
+ "why": "because", + "difficult": _NAN, + "also difficult 🔥": _NOW, + }, + }, + { + "parameters": { + "also difficult 🔥": _NOW, + "difficult": _NAN, + "why": "because", + "how much": 1.5, + }, + "credentials": "t0pS3cr3t!!11", + "scheme": "funky", + }, + ), + ), +) +def test_key_auth_equality(auth1: dict, auth2: dict) -> None: + cache = AsyncHomeDbCache() + key1 = cache.compute_key(None, auth1) + key2 = cache.compute_key(None, auth2) + + assert len(cache) == 0 + + cache.set(key1, "value") + assert len(cache) == 1 + assert cache.get(key1) == "value" + + cache.set(key2, "value2") + assert len(cache) == 1 + assert cache.get(key1) == "value2" + assert cache.get(key2) == "value2" + + assert key1 == key2 + + +def _assert_entries( + cache: AsyncHomeDbCache, + expected_entries: t.Collection[tuple[TKey, str]], +) -> None: + __tracebackhide__ = True + assert len(cache) == len(expected_entries) + for key, value in expected_entries: + assert cache.get(key) == value + + +def _force_cache_clean( + cache: AsyncHomeDbCache, + now: float | None = None, +) -> None: + cache._clean(now) + + +def test_cache_ttl() -> None: + t0 = datetime(1970, 1, 1) + with freezegun.freeze_time(t0) as time: + cache = AsyncHomeDbCache(ttl=1) + + entries = [] + for i in range(1, 11): + time.move_to(t0 + timedelta(seconds=0.25) * (i - 1)) + + entries.append((cache.compute_key(f"{i}", None), f"value{i}")) + key, value = entries[-1] + cache.set(key, value) + + _force_cache_clean(cache) + _assert_entries(cache, entries) + + time.move_to( + t0 + timedelta(seconds=0.25) * i - timedelta(milliseconds=1) + ) + _force_cache_clean(cache) + _assert_entries(cache, entries) + + time.move_to( + t0 + timedelta(seconds=0.25) * i + timedelta(milliseconds=1) + ) + _force_cache_clean(cache) + entries = entries[-3:] + _assert_entries(cache, entries) + + +def test_cache_ttl_empty_cache() -> None: + t0 = datetime(1970, 1, 1) + with freezegun.freeze_time(t0) as time: + cache = AsyncHomeDbCache(ttl=1) + assert len(cache) == 0 + _force_cache_clean(cache) + assert len(cache) == 0 + + time.move_to(t0 + timedelta(seconds=1, milliseconds=1)) + _force_cache_clean(cache) + assert len(cache) == 0 + + +def test_does_not_return_expired_entries() -> None: + t0 = datetime(1970, 1, 1) + with freezegun.freeze_time(t0) as time: + cache = AsyncHomeDbCache(ttl=1) + key = cache.compute_key("key", None) + value = "value" + + cache.set(cache.compute_key("key", None), value) + assert cache.get(key) == value + + time.move_to(t0 + timedelta(seconds=1, milliseconds=1)) + assert cache.get(key) is None + + +def test_cache_max_size() -> None: + cache = AsyncHomeDbCache(max_size=4) + + entries = [] + for i in range(1, 11): + entries.append((cache.compute_key(f"{i}", None), f"value{i}")) + entries = entries[-4:] + key, value = entries[-1] + cache.set(key, value) + + _force_cache_clean(cache) + _assert_entries(cache, entries) + + +def test_cache_max_size_empty_cache() -> None: + cache = AsyncHomeDbCache(max_size=1) + assert len(cache) == 0 + _force_cache_clean(cache) + assert len(cache) == 0 diff --git a/tests/unit/mixed/test_home_db_cache.py b/tests/unit/mixed/test_home_db_cache.py new file mode 100644 index 00000000..466f1481 --- /dev/null +++ b/tests/unit/mixed/test_home_db_cache.py @@ -0,0 +1,88 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import random
+from collections import defaultdict
+from concurrent.futures import (
+    as_completed,
+    ThreadPoolExecutor,
+)
+from time import monotonic
+
+from neo4j._sync.home_db_cache import HomeDbCache
+
+
+# No async equivalent exists, because the async home db cache is not really
+# async. As there's no IO involved, there's no need for locking in the async
+# world.
+def test_concurrent_home_db_cache_access() -> None:
+    workers = 25
+    duration = 5
+    value_pool_size = 50
+
+    cache = HomeDbCache(ttl=0.001, max_size=value_pool_size - 2)
+    keys = tuple(
+        cache.compute_key(user, None)
+        for user in map(str, range(1, value_pool_size + 1))
+    )
+
+    def worker(worker_id, end):
+        non_checks = checks = 0
+
+        value_counter = defaultdict(int)
+        while monotonic() < end:
+            for _ in range(20):  # to not check time too often
+                i = random.randint(0, len(keys) - 1)
+                value_count = value_counter[i]
+                key = keys[i]
+                rand = random.random()
+                if rand < 0.1:
+                    cache.set(key, None)
+                    res = cache.get(key)
+                    # Never want to read back this worker's own value
+                    assert res is None or not res.startswith(f"{worker_id}-")
+                elif rand < 0.55:
+                    value_counter[i] += 1
+                    value = f"{worker_id}-{value_count + 1}"
+                    cache.set(key, value)
+                    res = cache.get(key)
+                    if res is not None and res.startswith(f"{worker_id}-"):
+                        # never want to read back an old value of this worker
+                        checks += 1
+                        assert res == value
+                    else:
+                        non_checks += 1
+                else:
+                    res = cache.get(key)
+                    if res is not None and res.startswith(f"{worker_id}-"):
+                        # never want to read back an old value of this worker
+                        checks += 1
+                        assert res == f"{worker_id}-{value_count}"
+                    else:
+                        non_checks += 1
+
+        # import json
+        # print(
+        #     f"{worker_id}:\n"
+        #     f"{json.dumps(value_counter, indent=2)}\n"
+        #     f"checks: {checks}, non_checks: {non_checks}\n",
+        #     flush=True,
+        # )
+
+    with ThreadPoolExecutor(max_workers=workers) as executor:
+        end = monotonic() + duration
+        futures = (executor.submit(worker, i, end) for i in range(workers))
+        for future in as_completed(futures):
+            future.result()
diff --git a/tests/unit/sync/test_home_db_cache.py b/tests/unit/sync/test_home_db_cache.py
new file mode 100644
index 00000000..95b793f5
--- /dev/null
+++ b/tests/unit/sync/test_home_db_cache.py
@@ -0,0 +1,239 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [https://neo4j.com]
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
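Note for reviewers: the TTL and max-size eviction exercised by the cache tests in this patch is spread across the twin async/sync diffs above, which makes it hard to review in one piece. The following condensed, standalone sketch restates the `_clean()` strategy introduced in `home_db_cache.py`; the class name `TtlCacheSketch` and the untyped keys are illustrative only, not driver API.

    import heapq
    from time import monotonic


    class TtlCacheSketch:
        """Condensed restatement of the home db cache eviction (not driver API)."""

        def __init__(self, ttl=5.0, max_size=200):
            self._ttl = ttl
            self._max_size = max_size
            self._cache = {}  # key -> (write timestamp, value)
            self._oldest_entry = monotonic()

        def get(self, key):
            self._clean(monotonic())
            val = self._cache.get(key)
            return None if val is None else val[1]

        def set(self, key, value):
            now = monotonic()
            self._clean(now)
            if value is None:
                self._cache.pop(key, None)
            else:
                self._cache[key] = (now, value)

        def _clean(self, now):
            # TTL pass: only bother once the oldest entry may have expired.
            if now - self._oldest_entry > self._ttl:
                self._cache = {
                    k: v
                    for k, v in self._cache.items()
                    if now - v[0] < self._ttl
                }
                self._oldest_entry = min(
                    (v[0] for v in self._cache.values()), default=now
                )
            # Size pass: keep only the max_size most recently written entries.
            if self._max_size and len(self._cache) > self._max_size:
                self._cache = dict(
                    heapq.nlargest(
                        self._max_size,
                        self._cache.items(),
                        key=lambda item: item[1][0],
                    )
                )

The design choice the tests rely on: cleaning piggybacks on every `get`/`set` under the lock (and can be forced directly, as the tests' `_force_cache_clean()` helper does with a frozen clock) instead of running on a timer, and the TTL pass is skipped entirely until the oldest entry could plausibly have expired.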
+ + +from __future__ import annotations + +import typing as t +from datetime import ( + datetime, + timedelta, +) + +import freezegun +import pytest +import pytz + +from neo4j._sync.home_db_cache import HomeDbCache +from neo4j.time import DateTime + + +if t.TYPE_CHECKING: + from neo4j._sync.home_db_cache import TKey + + +@pytest.mark.parametrize("enabled", (True, False)) +def test_key_none_is_none(enabled: bool) -> None: + assert HomeDbCache(enabled=enabled).compute_key(None, None) == (None,) + + +@pytest.mark.parametrize( + "auth", + ( + None, + {}, + {"scheme": "basic", "principal": "neo4j", "credentials": "password"}, + {"scheme": "custom", "principal": "neo4j", "credentials": "password"}, + {"scheme": "custom", "credentials": "nice token"}, + {"foo": "bar"}, + ), +) +@pytest.mark.parametrize("enabled", (True, False)) +def test_key_imp_precedence_over_auth( + auth: dict | None, + enabled: bool, +) -> None: + cache = HomeDbCache(enabled=enabled) + assert cache.compute_key("bob", auth) == ("bob" if enabled else (None,)) + + +@pytest.mark.parametrize( + "auth", + ( + {}, + {"scheme": "basic", "principal": "neo4j", "credentials": "password"}, + {"scheme": "basic", "principal": "this is wrong, no password?"}, + {"scheme": "basic", "credentials": "this is wrong, no user?"}, + {"scheme": "none"}, + {"scheme": "none", "principal": "even though the scheme is none"}, + {"scheme": "kerberos", "principal": "", "credentials": "ticket"}, + {"scheme": "bearer", "credentials": "nice SSO token"}, + {"scheme": "custom", "principal": "neo4j", "credentials": "password"}, + {"scheme": "custom", "credentials": "bar", "parameters": {"oh": "hi"}}, + {"foo": "bar"}, + ), +) +def test_key_reduces_basic_auth_to_principal(auth: dict) -> None: + key = HomeDbCache().compute_key(None, auth) + if auth.get("scheme") == "basic" and "principal" in auth: + assert isinstance(key, str) + assert key == auth["principal"] + else: + assert isinstance(key, tuple) + for e in key: + assert isinstance(e, tuple) and len(e) == 2 + assert isinstance(e[0], str) + + +_NAN = float("nan") +_NOW = pytz.timezone("Europe/Stockholm").localize( + DateTime(2021, 8, 12, 12, 34, 57, 123456789) +) + + +@pytest.mark.parametrize( + ("auth1", "auth2"), + ( + ({}, {}), + ({"foo": "bar"}, {"foo": "bar"}), + ({"a": 1, "b": 2}, {"b": 2, "a": 1}), + ( + { + "scheme": "funky", + "credentials": "t0pS3cr3t!!11", + "parameters": { + "how much": 1.5, + # Note: for special values (NaN, temporal types, etc.), + # equality may rely on object identity. 
+ "why": "because", + "difficult": _NAN, + "also difficult 🔥": _NOW, + }, + }, + { + "parameters": { + "also difficult 🔥": _NOW, + "difficult": _NAN, + "why": "because", + "how much": 1.5, + }, + "credentials": "t0pS3cr3t!!11", + "scheme": "funky", + }, + ), + ), +) +def test_key_auth_equality(auth1: dict, auth2: dict) -> None: + cache = HomeDbCache() + key1 = cache.compute_key(None, auth1) + key2 = cache.compute_key(None, auth2) + + assert len(cache) == 0 + + cache.set(key1, "value") + assert len(cache) == 1 + assert cache.get(key1) == "value" + + cache.set(key2, "value2") + assert len(cache) == 1 + assert cache.get(key1) == "value2" + assert cache.get(key2) == "value2" + + assert key1 == key2 + + +def _assert_entries( + cache: HomeDbCache, + expected_entries: t.Collection[tuple[TKey, str]], +) -> None: + __tracebackhide__ = True + assert len(cache) == len(expected_entries) + for key, value in expected_entries: + assert cache.get(key) == value + + +def _force_cache_clean( + cache: HomeDbCache, + now: float | None = None, +) -> None: + cache._clean(now) + + +def test_cache_ttl() -> None: + t0 = datetime(1970, 1, 1) + with freezegun.freeze_time(t0) as time: + cache = HomeDbCache(ttl=1) + + entries = [] + for i in range(1, 11): + time.move_to(t0 + timedelta(seconds=0.25) * (i - 1)) + + entries.append((cache.compute_key(f"{i}", None), f"value{i}")) + key, value = entries[-1] + cache.set(key, value) + + _force_cache_clean(cache) + _assert_entries(cache, entries) + + time.move_to( + t0 + timedelta(seconds=0.25) * i - timedelta(milliseconds=1) + ) + _force_cache_clean(cache) + _assert_entries(cache, entries) + + time.move_to( + t0 + timedelta(seconds=0.25) * i + timedelta(milliseconds=1) + ) + _force_cache_clean(cache) + entries = entries[-3:] + _assert_entries(cache, entries) + + +def test_cache_ttl_empty_cache() -> None: + t0 = datetime(1970, 1, 1) + with freezegun.freeze_time(t0) as time: + cache = HomeDbCache(ttl=1) + assert len(cache) == 0 + _force_cache_clean(cache) + assert len(cache) == 0 + + time.move_to(t0 + timedelta(seconds=1, milliseconds=1)) + _force_cache_clean(cache) + assert len(cache) == 0 + + +def test_does_not_return_expired_entries() -> None: + t0 = datetime(1970, 1, 1) + with freezegun.freeze_time(t0) as time: + cache = HomeDbCache(ttl=1) + key = cache.compute_key("key", None) + value = "value" + + cache.set(cache.compute_key("key", None), value) + assert cache.get(key) == value + + time.move_to(t0 + timedelta(seconds=1, milliseconds=1)) + assert cache.get(key) is None + + +def test_cache_max_size() -> None: + cache = HomeDbCache(max_size=4) + + entries = [] + for i in range(1, 11): + entries.append((cache.compute_key(f"{i}", None), f"value{i}")) + entries = entries[-4:] + key, value = entries[-1] + cache.set(key, value) + + _force_cache_clean(cache) + _assert_entries(cache, entries) + + +def test_cache_max_size_empty_cache() -> None: + cache = HomeDbCache(max_size=1) + assert len(cache) == 0 + _force_cache_clean(cache) + assert len(cache) == 0 From 049e22e6fe353fbea7c69a17bb1825aad500b249 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Thu, 14 Nov 2024 15:28:01 +0100 Subject: [PATCH 10/26] More unit tests --- src/neo4j/_async/work/transaction.py | 4 +- src/neo4j/_sync/work/transaction.py | 4 +- tests/unit/async_/fixtures/fake_connection.py | 18 +- tests/unit/async_/fixtures/fake_pool.py | 4 + tests/unit/async_/io/test_neo4j_pool.py | 291 +++++++++++++----- tests/unit/async_/test_driver.py | 1 + tests/unit/async_/work/test_result.py | 80 +++-- 
tests/unit/async_/work/test_session.py | 183 ++++++++++- tests/unit/async_/work/test_transaction.py | 77 ++++- tests/unit/sync/fixtures/fake_connection.py | 18 +- tests/unit/sync/fixtures/fake_pool.py | 4 + tests/unit/sync/io/test_neo4j_pool.py | 291 +++++++++++++----- tests/unit/sync/test_driver.py | 1 + tests/unit/sync/work/test_result.py | 80 +++-- tests/unit/sync/work/test_session.py | 183 ++++++++++- tests/unit/sync/work/test_transaction.py | 77 ++++- 16 files changed, 1086 insertions(+), 230 deletions(-) diff --git a/src/neo4j/_async/work/transaction.py b/src/neo4j/_async/work/transaction.py index 4ce9937f..921a29ef 100644 --- a/src/neo4j/_async/work/transaction.py +++ b/src/neo4j/_async/work/transaction.py @@ -108,7 +108,9 @@ async def on_begin_success(metadata_): db=database, imp_user=imp_user, notifications_min_severity=notifications_min_severity, - notifications_disabled_classifications=notifications_disabled_classifications, + notifications_disabled_classifications=( + notifications_disabled_classifications + ), on_success=on_begin_success, ) if not pipelined: diff --git a/src/neo4j/_sync/work/transaction.py b/src/neo4j/_sync/work/transaction.py index c0a270f9..f8a6e461 100644 --- a/src/neo4j/_sync/work/transaction.py +++ b/src/neo4j/_sync/work/transaction.py @@ -108,7 +108,9 @@ def on_begin_success(metadata_): db=database, imp_user=imp_user, notifications_min_severity=notifications_min_severity, - notifications_disabled_classifications=notifications_disabled_classifications, + notifications_disabled_classifications=( + notifications_disabled_classifications + ), on_success=on_begin_success, ) if not pipelined: diff --git a/tests/unit/async_/fixtures/fake_connection.py b/tests/unit/async_/fixtures/fake_connection.py index 9bf96779..98c40df6 100644 --- a/tests/unit/async_/fixtures/fake_connection.py +++ b/tests/unit/async_/fixtures/fake_connection.py @@ -127,7 +127,14 @@ async def callback(): return func method_mock = parent.__getattr__(name) - if name in {"run", "commit", "pull", "rollback", "discard"}: + if name in { + "begin", + "run", + "commit", + "pull", + "rollback", + "discard", + }: method_mock.side_effect = build_message_handler(name) return method_mock @@ -218,7 +225,14 @@ async def callback(): return func method_mock = parent.__getattr__(name) - if name in {"run", "commit", "pull", "rollback", "discard"}: + if name in { + "begin", + "run", + "commit", + "pull", + "rollback", + "discard", + }: method_mock.side_effect = build_message_handler(name) return method_mock diff --git a/tests/unit/async_/fixtures/fake_pool.py b/tests/unit/async_/fixtures/fake_pool.py index 877c22fd..892d7b0d 100644 --- a/tests/unit/async_/fixtures/fake_pool.py +++ b/tests/unit/async_/fixtures/fake_pool.py @@ -17,6 +17,7 @@ import pytest from neo4j._async.config import AsyncPoolConfig +from neo4j._async.home_db_cache import AsyncHomeDbCache from neo4j._async.io._pool import AsyncIOPool @@ -32,6 +33,9 @@ def async_fake_pool(async_fake_connection_generator, mocker): pool.buffered_connection_mocks = [] pool.acquired_connection_mocks = [] pool.pool_config = AsyncPoolConfig() + pool.ssr_enabled = False + pool.is_direct_pool = True + pool.home_db_cache = AsyncHomeDbCache(enabled=False) def acquire_side_effect(*_, **__): if pool.buffered_connection_mocks: diff --git a/tests/unit/async_/io/test_neo4j_pool.py b/tests/unit/async_/io/test_neo4j_pool.py index c0be16ad..e1549fd0 100644 --- a/tests/unit/async_/io/test_neo4j_pool.py +++ b/tests/unit/async_/io/test_neo4j_pool.py @@ -24,6 +24,7 @@ ) from 
neo4j._async.config import AsyncPoolConfig from neo4j._async.io import ( + AcquisitionDatabase, AsyncBolt, AsyncNeo4jPool, ) @@ -53,11 +54,27 @@ WRITER1_ADDRESS = ResolvedAddress(("1.2.3.1", 9020), host_name="host") +def make_home_db_resolve(home_db): + def _home_db_resolve(db): + return db or home_db + + return _home_db_resolve + + +_default_db_resolve = make_home_db_resolve("neo4j") + + @pytest.fixture def custom_routing_opener(async_fake_connection_generator, mocker): - def make_opener(failures=None, get_readers=None): + def make_opener( + failures=None, + get_readers=None, + db_resolve=_default_db_resolve, + on_open=None, + ): def routing_side_effect(*args, **kwargs): nonlocal failures + opener_.route_requests.append(kwargs.get("database")) res = next(failures, None) if res is None: routers = [ @@ -70,16 +87,18 @@ def routing_side_effect(*args, **kwargs): else: readers = [str(READER1_ADDRESS)] writers = [str(WRITER1_ADDRESS)] - return [ - { - "ttl": 1000, - "servers": [ - {"addresses": routers, "role": "ROUTE"}, - {"addresses": readers, "role": "READ"}, - {"addresses": writers, "role": "WRITE"}, - ], - } - ] + rt = { + "ttl": 1000, + "servers": [ + {"addresses": routers, "role": "ROUTE"}, + {"addresses": readers, "role": "READ"}, + {"addresses": writers, "role": "WRITE"}, + ], + } + db = db_resolve(kwargs.get("database")) + if db is not ...: + rt["db"] = db + return [rt] raise res async def open_(addr, auth, timeout): @@ -92,11 +111,16 @@ async def open_(addr, auth, timeout): route_mock.side_effect = routing_side_effect connection.attach_mock(route_mock, "route") opener_.connections.append(connection) + + if callable(on_open): + on_open(connection) + return connection failures = iter(failures or []) opener_ = mocker.AsyncMock() opener_.connections = [] + opener_.route_requests = [] opener_.side_effect = open_ return opener_ @@ -124,54 +148,101 @@ def _simple_pool(opener) -> AsyncNeo4jPool: ) +TEST_DB1 = AcquisitionDatabase("test_db1") +TEST_DB2 = AcquisitionDatabase("test_db2") + + +@pytest.mark.parametrize("guessed_db", (True, False)) @mark_async_test -async def test_acquires_new_routing_table_if_deleted(opener): +async def test_acquires_new_routing_table_if_deleted( + custom_routing_opener, + guessed_db, +) -> None: + db = AcquisitionDatabase("test_db", guessed=guessed_db) + opener = custom_routing_opener(db_resolve=make_home_db_resolve(db.name)) pool = _simple_pool(opener) - cx = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx = await pool.acquire(READ_ACCESS, 30, db, None, None, None) await pool.release(cx) - assert pool.routing_tables.get("test_db") + assert pool.routing_tables.get(db.name) + assert opener.route_requests == [None if guessed_db else db.name] + opener.route_requests = [] - del pool.routing_tables["test_db"] + del pool.routing_tables[db.name] - cx = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx = await pool.acquire(READ_ACCESS, 30, db, None, None, None) await pool.release(cx) - assert pool.routing_tables.get("test_db") + assert pool.routing_tables.get(db.name) + assert opener.route_requests == [None if guessed_db else db.name] +@pytest.mark.parametrize("guessed_db", (True, False)) @mark_async_test -async def test_acquires_new_routing_table_if_stale(opener): +async def test_acquires_new_routing_table_if_stale( + custom_routing_opener, + guessed_db, +) -> None: + db = AcquisitionDatabase("test_db", guessed=guessed_db) + opener = custom_routing_opener(db_resolve=make_home_db_resolve(db.name)) pool = _simple_pool(opener) - cx 
= await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx = await pool.acquire(READ_ACCESS, 30, db, None, None, None) await pool.release(cx) - assert pool.routing_tables.get("test_db") + assert pool.routing_tables.get(db.name) + assert opener.route_requests == [None if guessed_db else db.name] + opener.route_requests = [] - old_value = pool.routing_tables["test_db"].last_updated_time - pool.routing_tables["test_db"].ttl = 0 + old_value = pool.routing_tables[db.name].last_updated_time + pool.routing_tables[db.name].ttl = 0 - cx = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx = await pool.acquire(READ_ACCESS, 30, db, None, None, None) await pool.release(cx) - assert pool.routing_tables["test_db"].last_updated_time > old_value + assert pool.routing_tables[db.name].last_updated_time > old_value + assert opener.route_requests == [None if guessed_db else db.name] @mark_async_test async def test_removes_old_routing_table(opener): pool = _simple_pool(opener) - cx = await pool.acquire(READ_ACCESS, 30, "test_db1", None, None, None) + cx = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) await pool.release(cx) - assert pool.routing_tables.get("test_db1") - cx = await pool.acquire(READ_ACCESS, 30, "test_db2", None, None, None) + assert pool.routing_tables.get(TEST_DB1.name) + cx = await pool.acquire(READ_ACCESS, 30, TEST_DB2, None, None, None) await pool.release(cx) - assert pool.routing_tables.get("test_db2") + assert pool.routing_tables.get(TEST_DB2.name) - old_value = pool.routing_tables["test_db1"].last_updated_time - pool.routing_tables["test_db1"].ttl = 0 - db2_rt = pool.routing_tables["test_db2"] + old_value = pool.routing_tables[TEST_DB1.name].last_updated_time + pool.routing_tables[TEST_DB1.name].ttl = 0 + db2_rt = pool.routing_tables[TEST_DB2.name] db2_rt.ttl = -RoutingConfig.routing_table_purge_delay - cx = await pool.acquire(READ_ACCESS, 30, "test_db1", None, None, None) + cx = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) await pool.release(cx) - assert pool.routing_tables["test_db1"].last_updated_time > old_value - assert "test_db2" not in pool.routing_tables + assert pool.routing_tables[TEST_DB1.name].last_updated_time > old_value + assert TEST_DB2.name not in pool.routing_tables + + +@pytest.mark.parametrize("guessed_db", (True, False)) +@mark_async_test +async def test_db_resolution_callback(custom_routing_opener, guessed_db): + cb_calls = [] + + def cb(db_): + nonlocal cb_calls + cb_calls.append(db_) + + db = AcquisitionDatabase("test_db", guessed=guessed_db) + home_db = "home_db" + expected_target_db = home_db if db.guessed else db.name + + opener = custom_routing_opener(db_resolve=make_home_db_resolve(home_db)) + pool = _simple_pool(opener) + cx = await pool.acquire( + READ_ACCESS, 30, db, None, None, None, database_callback=cb + ) + await pool.release(cx) + + assert pool.routing_tables.get(expected_target_db) + assert opener.route_requests == [None if guessed_db else db.name] + assert cb_calls == [expected_target_db] @pytest.mark.parametrize("type_", ("r", "w")) @@ -181,7 +252,7 @@ async def test_chooses_right_connection_type(opener, type_): cx1 = await pool.acquire( READ_ACCESS if type_ == "r" else WRITE_ACCESS, 30, - "test_db", + TEST_DB1, None, None, None, @@ -196,9 +267,9 @@ async def test_chooses_right_connection_type(opener, type_): @mark_async_test async def test_reuses_connection(opener): pool = _simple_pool(opener) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = await 
pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) await pool.release(cx1) - cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx2 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) assert cx1 is cx2 @@ -216,7 +287,7 @@ async def break_connection(): return None pool = _simple_pool(opener) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) await pool.release(cx1) assert cx1 in pool.connections[cx1.unresolved_address] # simulate connection going stale (e.g. exceeding idle timeout) and then @@ -226,7 +297,7 @@ async def break_connection(): if break_on_close: cx_close_mock_side_effect = cx_close_mock.side_effect cx_close_mock.side_effect = break_connection - cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx2 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) await pool.release(cx2) if break_on_close: cx1.close.assert_called() @@ -241,12 +312,12 @@ async def break_connection(): @mark_async_test async def test_does_not_close_stale_connections_in_use(opener): pool = _simple_pool(opener) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) assert cx1 in pool.connections[cx1.unresolved_address] # simulate connection going stale (e.g. exceeding idle timeout) while being # in use cx1.stale.return_value = True - cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx2 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) await pool.release(cx2) cx1.close.assert_not_called() assert cx2 is not cx1 @@ -259,7 +330,7 @@ async def test_does_not_close_stale_connections_in_use(opener): # it should be closed when trying to acquire the next connection cx1.close.assert_not_called() - cx3 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx3 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) await pool.release(cx3) cx1.close.assert_called_once() assert cx2 is cx3 @@ -271,7 +342,7 @@ async def test_does_not_close_stale_connections_in_use(opener): @mark_async_test async def test_release_resets_connections(opener): pool = _simple_pool(opener) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) cx1.is_reset_mock.return_value = False cx1.is_reset_mock.reset_mock() await pool.release(cx1) @@ -282,7 +353,7 @@ async def test_release_resets_connections(opener): @mark_async_test async def test_release_does_not_resets_closed_connections(opener): pool = _simple_pool(opener) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) cx1.closed.return_value = True cx1.closed.reset_mock() cx1.is_reset_mock.reset_mock() @@ -295,7 +366,7 @@ async def test_release_does_not_resets_closed_connections(opener): @mark_async_test async def test_release_does_not_resets_defunct_connections(opener): pool = _simple_pool(opener) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) cx1.defunct.return_value = True cx1.defunct.reset_mock() cx1.is_reset_mock.reset_mock() @@ -457,8 +528,8 @@ async def close_side_effect(): # create pool with 2 idle connections pool = _simple_pool(opener) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, 
None, None) - cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) + cx2 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) await pool.release(cx1) await pool.release(cx2) @@ -470,7 +541,7 @@ async def close_side_effect(): # unreachable cx1.stale.return_value = True - cx3 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx3 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) assert cx3 is not cx1 assert cx3 is not cx2 @@ -479,11 +550,11 @@ async def close_side_effect(): @mark_async_test async def test_failing_opener_leaves_connections_in_use_alone(opener): pool = _simple_pool(opener) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) opener.side_effect = ServiceUnavailable("Server overloaded") with pytest.raises((ServiceUnavailable, SessionExpired)): - await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) assert not cx1.closed() @@ -505,7 +576,7 @@ async def test__acquire_new_later_without_room(opener): config = _pool_config() config.max_connection_pool_size = 1 pool = AsyncNeo4jPool(opener, config, WorkspaceConfig(), ROUTER1_ADDRESS) - _ = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + _ = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) # pool is full now assert pool.connections_reservations[READER1_ADDRESS] == 0 creator = pool._acquire_new_later(READER1_ADDRESS, None, Deadline(1)) @@ -559,13 +630,13 @@ async def test_discovery_is_retried(custom_routing_opener, error): WorkspaceConfig(), ResolvedAddress(("1.2.3.1", 9999), host_name="host"), ) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) await pool.release(cx1) - pool.routing_tables.get("test_db").ttl = 0 + pool.routing_tables.get(TEST_DB1.name).ttl = 0 - cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx2 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) await pool.release(cx2) - assert pool.routing_tables.get("test_db") + assert pool.routing_tables.get(TEST_DB1.name) assert cx1 is cx2 @@ -611,12 +682,12 @@ async def test_fast_failing_discovery(custom_routing_opener, error): WorkspaceConfig(), ResolvedAddress(("1.2.3.1", 9999), host_name="host"), ) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) await pool.release(cx1) - pool.routing_tables.get("test_db").ttl = 0 + pool.routing_tables.get(TEST_DB1.name).ttl = 0 with pytest.raises(error.__class__) as exc: - await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) assert exc.value is error @@ -657,11 +728,11 @@ async def test_connection_error_callback( config.auth = auth_manager pool = AsyncNeo4jPool(opener, config, WorkspaceConfig(), ROUTER1_ADDRESS) cxs_read = [ - await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) for _ in range(5) ] cxs_write = [ - await pool.acquire(WRITE_ACCESS, 30, "test_db", None, None, None) + await pool.acquire(WRITE_ACCESS, 30, TEST_DB1, None, None, None) for _ in range(5) ] @@ -690,7 +761,7 @@ async def test_connection_error_callback( 
@mark_async_test async def test_pool_closes_connections_dropped_from_rt(custom_routing_opener): - readers = {"db1": [str(READER1_ADDRESS)]} + readers = {TEST_DB1.name: [str(READER1_ADDRESS)]} def get_readers(database): return readers[database] @@ -700,7 +771,7 @@ def get_readers(database): pool = AsyncNeo4jPool( opener, _pool_config(), WorkspaceConfig(), ROUTER1_ADDRESS ) - cx1 = await pool.acquire(READ_ACCESS, 30, "db1", None, None, None) + cx1 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) assert cx1.unresolved_address == READER1_ADDRESS await pool.release(cx1) @@ -708,10 +779,10 @@ def get_readers(database): assert len(pool.connections[READER1_ADDRESS]) == 1 # force RT refresh, returning a different reader - del pool.routing_tables["db1"] - readers["db1"] = [str(READER2_ADDRESS)] + del pool.routing_tables[TEST_DB1.name] + readers[TEST_DB1.name] = [str(READER2_ADDRESS)] - cx2 = await pool.acquire(READ_ACCESS, 30, "db1", None, None, None) + cx2 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) assert cx2.unresolved_address == READER2_ADDRESS cx1.close.assert_awaited_once() @@ -726,8 +797,8 @@ async def test_pool_does_not_close_connections_dropped_from_rt_for_other_server( custom_routing_opener, ): readers = { - "db1": [str(READER1_ADDRESS), str(READER2_ADDRESS)], - "db2": [str(READER1_ADDRESS)], + TEST_DB1.name: [str(READER1_ADDRESS), str(READER2_ADDRESS)], + TEST_DB2.name: [str(READER1_ADDRESS)], } def get_readers(database): @@ -738,14 +809,14 @@ def get_readers(database): pool = AsyncNeo4jPool( opener, _pool_config(), WorkspaceConfig(), ROUTER1_ADDRESS ) - cx1 = await pool.acquire(READ_ACCESS, 30, "db1", None, None, None) + cx1 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) await pool.release(cx1) assert cx1.unresolved_address in {READER1_ADDRESS, READER2_ADDRESS} reader1_connection_count = len(pool.connections[READER1_ADDRESS]) reader2_connection_count = len(pool.connections[READER2_ADDRESS]) assert reader1_connection_count + reader2_connection_count == 1 - cx2 = await pool.acquire(READ_ACCESS, 30, "db2", None, None, None) + cx2 = await pool.acquire(READ_ACCESS, 30, TEST_DB2, None, None, None) await pool.release(cx2) assert cx2.unresolved_address == READER1_ADDRESS cx1.close.assert_not_called() @@ -754,10 +825,10 @@ def get_readers(database): assert len(pool.connections[READER2_ADDRESS]) == reader2_connection_count # force RT refresh, returning a different reader - del pool.routing_tables["db2"] - readers["db2"] = [str(READER3_ADDRESS)] + del pool.routing_tables[TEST_DB2.name] + readers[TEST_DB2.name] = [str(READER3_ADDRESS)] - cx3 = await pool.acquire(READ_ACCESS, 30, "db2", None, None, None) + cx3 = await pool.acquire(READ_ACCESS, 30, TEST_DB2, None, None, None) await pool.release(cx3) assert cx3.unresolved_address == READER3_ADDRESS @@ -767,3 +838,79 @@ def get_readers(database): assert len(pool.connections[READER1_ADDRESS]) == 1 assert len(pool.connections[READER2_ADDRESS]) == reader2_connection_count assert len(pool.connections[READER3_ADDRESS]) == 1 + + +@mark_async_test +async def test_tracks_ssr_connection_hints(custom_routing_opener): + connection_count = 0 + + def on_open(connection): + if connection.unresolved_address in { + ROUTER1_ADDRESS, + ROUTER2_ADDRESS, + ROUTER3_ADDRESS, + }: + connection.ssr_enabled = True + return + nonlocal connection_count + connection_count += 1 + connection.ssr_enabled = connection_count != 2 + + opener = custom_routing_opener(on_open=on_open) + pool = AsyncNeo4jPool( + opener, 
_pool_config(), WorkspaceConfig(), ROUTER1_ADDRESS + ) + + # no connection in pool => cannot know => defensive assumption: off + assert not pool.ssr_enabled + + # open 1st reader connection (supports SSR) + cx1 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) + assert cx1.ssr_enabled # double check we got the mocking right + + assert pool.ssr_enabled + + # open 2nd reader connection (does not support SSR) + cx2 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) + assert not cx2.ssr_enabled # double check we got the mocking right + + assert not pool.ssr_enabled + + # open 3rd reader connection (supports SSR) + cx3 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) + assert cx3.ssr_enabled # double check we got the mocking right + + assert not pool.ssr_enabled + + await pool.release(cx1) + await pool.release(cx2) + await pool.release(cx3) + + assert not pool.ssr_enabled + + cxs = [ + await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) + for _ in range(3) + ] + assert sum(not c.ssr_enabled for c in cxs) == 1 # double check + + for cx in (cx for cx in cxs if not cx.ssr_enabled): + await cx.close() + + # after the single connection without SSR support is closed + for cx in cxs: + await pool.release(cx) + + # force pool cleaning up all stale connections: + cxs = [ + await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) + for _ in range(3) + ] + assert all(cx.ssr_enabled for cx in cxs) # double check + + assert pool.ssr_enabled + + for cx in cxs: + await pool.release(cx) + + assert pool.ssr_enabled diff --git a/tests/unit/async_/test_driver.py b/tests/unit/async_/test_driver.py index 53d28710..9e7b657b 100644 --- a/tests/unit/async_/test_driver.py +++ b/tests/unit/async_/test_driver.py @@ -277,6 +277,7 @@ async def test_driver_opens_write_session_by_default( bookmarks=mocker.ANY, auth=mocker.ANY, liveness_check_timeout=mocker.ANY, + database_callback=mocker.ANY, ) tx._begin.assert_awaited_once_with( mocker.ANY, diff --git a/tests/unit/async_/work/test_result.py b/tests/unit/async_/work/test_result.py index 6141b998..01e05443 100644 --- a/tests/unit/async_/work/test_result.py +++ b/tests/unit/async_/work/test_result.py @@ -315,7 +315,7 @@ async def fetch_and_compare_all_records( @mark_async_test async def test_result_iteration(method, records): connection = AsyncConnectionStub(records=Records(["x"], records)) - result = AsyncResult(connection, 2, None, noop, noop) + result = AsyncResult(connection, 2, None, noop, noop, None) await result._run("CYPHER", {}, None, None, "r", None, None, None) await fetch_and_compare_all_records(result, "x", records, method) @@ -324,7 +324,7 @@ async def test_result_iteration(method, records): async def test_result_iteration_mixed_methods(): records = [[i] for i in range(10)] connection = AsyncConnectionStub(records=Records(["x"], records)) - result = AsyncResult(connection, 4, None, noop, noop) + result = AsyncResult(connection, 4, None, noop, noop, None) await result._run("CYPHER", {}, None, None, "r", None, None, None) iter1 = AsyncUtil.iter(result) iter2 = AsyncUtil.iter(result) @@ -372,9 +372,9 @@ async def test_parallel_result_iteration(method, invert_fetch): connection = AsyncConnectionStub( records=(Records(["x"], records1), Records(["x"], records2)) ) - result1 = AsyncResult(connection, 2, None, noop, noop) + result1 = AsyncResult(connection, 2, None, noop, noop, None) await result1._run("CYPHER1", {}, None, None, "r", None, None, None) - result2 = AsyncResult(connection, 2, None, noop, 
noop) + result2 = AsyncResult(connection, 2, None, noop, noop, None) await result2._run("CYPHER2", {}, None, None, "r", None, None, None) if invert_fetch: await fetch_and_compare_all_records(result2, "x", records2, method) @@ -395,9 +395,9 @@ async def test_interwoven_result_iteration(method, invert_fetch): connection = AsyncConnectionStub( records=(Records(["x"], records1), Records(["y"], records2)) ) - result1 = AsyncResult(connection, 2, None, noop, noop) + result1 = AsyncResult(connection, 2, None, noop, noop, None) await result1._run("CYPHER1", {}, None, None, "r", None, None, None) - result2 = AsyncResult(connection, 2, None, noop, noop) + result2 = AsyncResult(connection, 2, None, noop, noop, None) await result2._run("CYPHER2", {}, None, None, "r", None, None, None) start = 0 for n in (1, 2, 3, 1, None): @@ -424,7 +424,7 @@ async def test_interwoven_result_iteration(method, invert_fetch): @mark_async_test async def test_result_peek(records, fetch_size): connection = AsyncConnectionStub(records=Records(["x"], records)) - result = AsyncResult(connection, fetch_size, None, noop, noop) + result = AsyncResult(connection, fetch_size, None, noop, noop, None) await result._run("CYPHER", {}, None, None, "r", None, None, None) for i in range(len(records) + 1): record = await result.peek() @@ -447,7 +447,7 @@ async def test_result_single_non_strict(records, fetch_size, default): kwargs["strict"] = False connection = AsyncConnectionStub(records=Records(["x"], records)) - result = AsyncResult(connection, fetch_size, None, noop, noop) + result = AsyncResult(connection, fetch_size, None, noop, noop, None) await result._run("CYPHER", {}, None, None, "r", None, None, None) if len(records) == 0: assert await result.single(**kwargs) is None @@ -466,7 +466,7 @@ async def test_result_single_non_strict(records, fetch_size, default): @mark_async_test async def test_result_single_strict(records, fetch_size): connection = AsyncConnectionStub(records=Records(["x"], records)) - result = AsyncResult(connection, fetch_size, None, noop, noop) + result = AsyncResult(connection, fetch_size, None, noop, noop, None) await result._run("CYPHER", {}, None, None, "r", None, None, None) if len(records) != 1: with pytest.raises(ResultNotSingleError) as exc: @@ -490,7 +490,7 @@ async def test_result_single_strict(records, fetch_size): @mark_async_test async def test_result_single_exhausts_records(records, fetch_size, strict): connection = AsyncConnectionStub(records=Records(["x"], records)) - result = AsyncResult(connection, fetch_size, None, noop, noop) + result = AsyncResult(connection, fetch_size, None, noop, noop, None) await result._run("CYPHER", {}, None, None, "r", None, None, None) try: with warnings.catch_warnings(): @@ -512,7 +512,7 @@ async def test_result_single_exhausts_records(records, fetch_size, strict): @mark_async_test async def test_result_fetch(records, fetch_size, strict): connection = AsyncConnectionStub(records=Records(["x"], records)) - result = AsyncResult(connection, fetch_size, None, noop, noop) + result = AsyncResult(connection, fetch_size, None, noop, noop, None) await result._run("CYPHER", {}, None, None, "r", None, None, None) assert await result.fetch(0) == [] assert await result.fetch(-1) == [] @@ -524,7 +524,7 @@ async def test_result_fetch(records, fetch_size, strict): @mark_async_test async def test_keys_are_available_before_and_after_stream(): connection = AsyncConnectionStub(records=Records(["x"], [[1], [2]])) - result = AsyncResult(connection, 1, None, noop, noop) + result = 
AsyncResult(connection, 1, None, noop, noop, None) await result._run("CYPHER", {}, None, None, "r", None, None, None) assert list(result.keys()) == ["x"] await AsyncUtil.list(result) @@ -540,7 +540,7 @@ async def test_consume(records, consume_one, summary_meta, consume_times): connection = AsyncConnectionStub( records=Records(["x"], records), summary_meta=summary_meta ) - result = AsyncResult(connection, 1, None, noop, noop) + result = AsyncResult(connection, 1, None, noop, noop, None) await result._run("CYPHER", {}, None, None, "r", None, None, None) if consume_one: with suppress(StopAsyncIteration): @@ -574,7 +574,7 @@ async def test_time_in_summary(t_first, t_last): summary_meta=summary_meta, ) - result = AsyncResult(connection, 1, None, noop, noop) + result = AsyncResult(connection, 1, None, noop, noop, None) await result._run("CYPHER", {}, None, None, "r", None, None, None) summary = await result.consume() @@ -596,7 +596,7 @@ async def test_time_in_summary(t_first, t_last): async def test_counts_in_summary(): connection = AsyncConnectionStub(records=Records(["n"], [[1], [2]])) - result = AsyncResult(connection, 1, None, noop, noop) + result = AsyncResult(connection, 1, None, noop, noop, None) await result._run("CYPHER", {}, None, None, "r", None, None, None) summary = await result.consume() @@ -610,7 +610,7 @@ async def test_query_type(query_type): records=Records(["n"], [[1], [2]]), summary_meta={"type": query_type} ) - result = AsyncResult(connection, 1, None, noop, noop) + result = AsyncResult(connection, 1, None, noop, noop, None) await result._run("CYPHER", {}, None, None, "r", None, None, None) summary = await result.consume() @@ -625,7 +625,7 @@ async def test_data(num_records): records=Records(["n"], [[i + 1] for i in range(num_records)]) ) - result = AsyncResult(connection, 1, None, noop, noop) + result = AsyncResult(connection, 1, None, noop, noop, None) await result._run("CYPHER", {}, None, None, "r", None, None, None) await result._buffer_all() records = result._record_buffer.copy() @@ -667,7 +667,7 @@ async def test_data(num_records): @mark_async_test async def test_result_graph(records): connection = AsyncConnectionStub(records=records) - result = AsyncResult(connection, 1, None, noop, noop) + result = AsyncResult(connection, 1, None, noop, noop, None) await result._run("CYPHER", {}, None, None, "r", None, None, None) graph = await result.graph() assert isinstance(graph, Graph) @@ -760,7 +760,7 @@ async def test_result_graph(records): async def test_to_eager_result(records): summary = {"test_to_eager_result": uuid.uuid4()} connection = AsyncConnectionStub(records=records, summary_meta=summary) - result = AsyncResult(connection, 1, None, noop, noop) + result = AsyncResult(connection, 1, None, noop, noop, None) await result._run("CYPHER", {}, None, None, "r", None, None, None) eager_result = await result.to_eager_result() @@ -850,7 +850,7 @@ async def test_to_eager_result(records): @mark_async_test async def test_to_df(keys, values, types, instances, test_default_expand): connection = AsyncConnectionStub(records=Records(keys, values)) - result = AsyncResult(connection, 1, None, noop, noop) + result = AsyncResult(connection, 1, None, noop, noop, None) await result._run("CYPHER", {}, None, None, "r", None, None, None) if test_default_expand: df = await result.to_df() @@ -1061,7 +1061,7 @@ async def test_to_df_expand( keys, values, expected_columns, expected_rows, expected_types ): connection = AsyncConnectionStub(records=Records(keys, values)) - result = 
AsyncResult(connection, 1, None, noop, noop) + result = AsyncResult(connection, 1, None, noop, noop, None) await result._run("CYPHER", {}, None, None, "r", None, None, None) df = await result.to_df(expand=True) @@ -1299,7 +1299,7 @@ async def test_to_df_expand( @mark_async_test async def test_to_df_parse_dates(keys, values, expected_df, expand): connection = AsyncConnectionStub(records=Records(keys, values)) - result = AsyncResult(connection, 1, None, noop, noop) + result = AsyncResult(connection, 1, None, noop, noop, None) await result._run("CYPHER", {}, None, None, "r", None, None, None) df = await result.to_df(expand=expand, parse_dates=True) @@ -1314,7 +1314,7 @@ async def test_broken_hydration(nested): value_in = [value_in] records_in = Records(["foo", "bar"], [["foobar", value_in]]) connection = AsyncConnectionStub(records=records_in) - result = AsyncResult(connection, 1, None, noop, noop) + result = AsyncResult(connection, 1, None, noop, noop, None) await result._run("CYPHER", {}, None, None, "r", None, None, None) records_out = await AsyncUtil.list(result) assert len(records_out) == 1 @@ -1422,3 +1422,37 @@ async def test_notification_logging( f"Received notification from DBMS server: {formatted_notification}" ) assert caplog.messages[0] == expected_message + + +@pytest.mark.parametrize("async_cb", (True, False)) +@pytest.mark.parametrize("resolved_db", (..., None, "resolved_db")) +@mark_async_test +async def test_on_database_callback(async_cb, resolved_db): + cb_calls = [] + + if async_cb: + + async def db_callback(db): + nonlocal cb_calls + cb_calls.append(db) + + else: + + def db_callback(db): + nonlocal cb_calls + cb_calls.append(db) + + run_meta = {} + if resolved_db is not ...: + run_meta["db"] = resolved_db + connection = AsyncConnectionStub( + records=Records(["foo"], ()), run_meta=run_meta + ) + + result = AsyncResult(connection, 1, None, noop, noop, db_callback) + await result._run("CYPHER", {}, None, None, "r", None, None, None) + + if resolved_db in {..., None}: + assert cb_calls == [] + else: + assert cb_calls == [resolved_db] diff --git a/tests/unit/async_/work/test_session.py b/tests/unit/async_/work/test_session.py index 6ec6fac2..781f2402 100644 --- a/tests/unit/async_/work/test_session.py +++ b/tests/unit/async_/work/test_session.py @@ -22,14 +22,19 @@ AsyncManagedTransaction, AsyncSession, AsyncTransaction, + Auth, Bookmarks, unit_of_work, ) from neo4j._api import TelemetryAPI +from neo4j._async.home_db_cache import AsyncHomeDbCache from neo4j._async.io import ( + AcquisitionDatabase, AsyncBoltPool, AsyncNeo4jPool, ) +from neo4j._async_compat.util import AsyncUtil +from neo4j._auth_management import to_auth_dict from neo4j._conf import SessionConfig from neo4j.api import ( AsyncBookmarkManager, @@ -430,12 +435,12 @@ async def work(tx): assert call.kwargs["parameters"] == expected_params -@pytest.mark.parametrize("db", (None, "adb")) -@pytest.mark.parametrize("routing", (True, False)) +@pytest.mark.parametrize("db", (None, "adb")[:1]) +@pytest.mark.parametrize("routing", (True, False)[:1]) # no home db resolution when connected to Neo4j 4.3 or earlier -@pytest.mark.parametrize("home_db_gets_resolved", (True, False)) +@pytest.mark.parametrize("home_db_gets_resolved", (True, False)[:1]) @pytest.mark.parametrize( - "additional_session_bookmarks", (None, ["session", "bookmarks"]) + "additional_session_bookmarks", (None, ["session", "bookmarks"])[:1] ) @mark_async_test async def test_with_bookmark_manager( @@ -490,8 +495,10 @@ async def bmm_get_bookmarks(): 
async_fake_pool.update_routing_table.side_effect = ( update_routing_table_side_effect ) + async_fake_pool.is_direct_pool = False else: async_fake_pool.mock_add_spec(AsyncBoltPool) + async_fake_pool.is_direct_pool = True config = SessionConfig() config.bookmark_manager = bmm @@ -699,3 +706,171 @@ async def work(_): connection_mock.telemetry.assert_called_once() call_args = connection_mock.telemetry.call_args.args assert call_args[0] == TelemetryAPI.DRIVER + + +@pytest.mark.parametrize( + ("db", "pool_ssr", "pool_routing", "expect_cache_usage"), + ( + (db, ssr, routing, ssr and routing and not db) + for ssr in (True, False) + for routing in (True, False) + for db in (None, "mydb") + ), +) +@pytest.mark.parametrize("imp_user", (None, "imp_user")) +@pytest.mark.parametrize( + "auth", + (None, Auth(scheme="magic-auth", principal=None, credentials="tada")), +) +@mark_async_test +async def test_uses_home_db_cache_when_expected( + async_fake_pool, + mocker, + db, + pool_ssr, + pool_routing, + expect_cache_usage, + imp_user, + auth, +): + async_fake_pool.ssr_enabled = pool_ssr + if pool_routing: + async_fake_pool.is_direct_pool = False + async_fake_pool.mock_add_spec(AsyncNeo4jPool) + cache_spy = mocker.Mock(spec=AsyncHomeDbCache, wraps=AsyncHomeDbCache()) + cached_db = "nice_cached_home_db" + key = object() + cache_spy.compute_key.return_value = key + cache_spy.get.return_value = cached_db + async_fake_pool.home_db_cache = cache_spy + + config = SessionConfig() + config.impersonated_user = imp_user + config.auth = auth + config.database = db + + async with AsyncSession(async_fake_pool, config) as session: + await session.run("RETURN 1") + + if expect_cache_usage: + # assert using cache + assert cache_spy.mock_calls == [ + mocker.call.compute_key( + imp_user, to_auth_dict(auth) if auth else None + ), + mocker.call.get(key), + ] + # assert passing cache result as a guess to the pool + async_fake_pool.acquire.assert_awaited_once_with( + access_mode=mocker.ANY, + timeout=mocker.ANY, + database=AcquisitionDatabase(cached_db, guessed=True), + bookmarks=mocker.ANY, + auth=mocker.ANY, + liveness_check_timeout=mocker.ANY, + database_callback=mocker.ANY, + ) + else: + # assert not using cache + cache_spy.get.assert_not_called() + # assert passing a non-guess to the pool + async_fake_pool.acquire.assert_awaited_once_with( + access_mode=mocker.ANY, + timeout=mocker.ANY, + database=AcquisitionDatabase(db, guessed=False), + bookmarks=mocker.ANY, + auth=mocker.ANY, + liveness_check_timeout=mocker.ANY, + database_callback=mocker.ANY, + ) + + +@pytest.mark.parametrize( + ("db", "pool_ssr", "pool_routing", "expect_cache_usage"), + ( + (db, ssr, routing, ssr and routing and not db) + for ssr in (True, False) + for routing in (True, False) + for db in (None, "mydb") + ), +) +@pytest.mark.parametrize("resolution_at", ("route", "run", "begin")) +@mark_async_test +async def test_pinns_session_db_with_cache( + async_fake_pool, + mocker, + db, + pool_ssr, + pool_routing, + expect_cache_usage, + resolution_at, +): + async def resolve_db(): + if resolution_at == "route": + database_callback = async_fake_pool.acquire.call_args.kwargs[ + "database_callback" + ] + await AsyncUtil.callback(database_callback, resolved_db) + elif resolution_at == "run": + database_callback = res_mock.call_args.args[-1] + await AsyncUtil.callback(database_callback, resolved_db) + elif resolution_at == "begin": + database_callback = tx_mock.call_args.args[-1] + await AsyncUtil.callback(database_callback, resolved_db) + else: + raise 
ValueError(f"Unknown resolution_at: {resolution_at}") + + if resolution_at == "run": + res_mock = mocker.patch( + "neo4j._async.work.session.AsyncResult", autospec=True + ) + elif resolution_at == "begin": + tx_mock = mocker.patch( + "neo4j._async.work.session.AsyncTransaction", autospec=True + ) + + resolved_db = "resolved_db" + async_fake_pool.ssr_enabled = pool_ssr + if pool_routing: + async_fake_pool.is_direct_pool = False + async_fake_pool.mock_add_spec(AsyncNeo4jPool) + cache_spy = mocker.Mock(spec=AsyncHomeDbCache, wraps=AsyncHomeDbCache()) + key = object() + cache_spy.compute_key.return_value = key + async_fake_pool.home_db_cache = cache_spy + + config = SessionConfig() + config.database = db + + async with AsyncSession(async_fake_pool, config) as session: + if resolution_at == "begin": + async with await session.begin_transaction() as tx: + await tx.run("RETURN 1") + else: + await session.run("RETURN 1") + + if expect_cache_usage: + # assert never using cache to pin a database + assert not session._pinned_database + assert config.database == db + + await resolve_db() + + assert session._pinned_database + assert config.database == resolved_db + cache_spy.set.assert_called_once_with(key, resolved_db) + else: + if not pool_routing or db: + assert session._pinned_database + assert config.database == db + + await resolve_db() + + if not pool_routing or db: + assert session._pinned_database + assert config.database == db + cache_spy.set.assert_not_called() + else: + cache_spy.set.assert_called_once_with(key, resolved_db) + assert session._pinned_database + assert config.database == resolved_db diff --git a/tests/unit/async_/work/test_transaction.py b/tests/unit/async_/work/test_transaction.py index ab23ce7b..787fae82 100644 --- a/tests/unit/async_/work/test_transaction.py +++ b/tests/unit/async_/work/test_transaction.py @@ -12,8 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
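A note on the callback plumbing: the pinning tests above and the transaction tests that follow both drive a ``database_callback`` by hand, and they are parametrized over sync and async callbacks because ``AsyncUtil.callback`` / ``Util.callback`` (from ``neo4j._async_compat.util``) must accept either kind. The helper's actual implementation is not part of this diff; the following is only a minimal sketch of the contract the tests assume:

.. code-block:: python

    import inspect

    async def callback(cb, *args):
        # Invoke a user-supplied callback and await the result if the
        # callback turned out to be a coroutine function.
        # Sketch only: not the shipped AsyncUtil.callback.
        if cb is None:
            return None
        res = cb(*args)
        if inspect.isawaitable(res):
            res = await res
        return res

With this shape, ``await AsyncUtil.callback(database_callback, resolved_db)`` behaves the same whether the test defined the callback with ``def`` or ``async def``.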
- - from unittest.mock import MagicMock import pytest @@ -52,7 +50,7 @@ async def test_transaction_context_when_committing( on_error = mocker.AsyncMock() on_cancel = mocker.Mock() tx = AsyncTransaction( - async_fake_connection, 2, None, on_closed, on_error, on_cancel + async_fake_connection, 2, None, on_closed, on_error, on_cancel, None ) mock_commit = mocker.patch.object(tx, "_commit", wraps=tx._commit) mock_rollback = mocker.patch.object(tx, "_rollback", wraps=tx._rollback) @@ -88,7 +86,7 @@ async def test_transaction_context_with_explicit_rollback( on_error = mocker.AsyncMock() on_cancel = mocker.Mock() tx = AsyncTransaction( - async_fake_connection, 2, None, on_closed, on_error, on_cancel + async_fake_connection, 2, None, on_closed, on_error, on_cancel, None ) mock_commit = mocker.patch.object(tx, "_commit", wraps=tx._commit) mock_rollback = mocker.patch.object(tx, "_rollback", wraps=tx._rollback) @@ -120,7 +118,7 @@ class OopsError(RuntimeError): on_error = MagicMock() on_cancel = MagicMock() tx = AsyncTransaction( - async_fake_connection, 2, None, on_closed, on_error, on_cancel + async_fake_connection, 2, None, on_closed, on_error, on_cancel, None ) mock_commit = mocker.patch.object(tx, "_commit", wraps=tx._commit) mock_rollback = mocker.patch.object(tx, "_rollback", wraps=tx._rollback) @@ -141,7 +139,7 @@ async def test_transaction_run_takes_no_query_object(async_fake_connection): on_error = MagicMock() on_cancel = MagicMock() tx = AsyncTransaction( - async_fake_connection, 2, None, on_closed, on_error, on_cancel + async_fake_connection, 2, None, on_closed, on_error, on_cancel, None ) with pytest.raises(ValueError): await tx.run(Query("RETURN 1")) @@ -165,7 +163,7 @@ async def test_transaction_run_parameters( on_error = MagicMock() on_cancel = MagicMock() tx = AsyncTransaction( - async_fake_connection, 2, None, on_closed, on_error, on_cancel + async_fake_connection, 2, None, on_closed, on_error, on_cancel, None ) if not as_kwargs: params = {"parameters": params} @@ -187,7 +185,9 @@ async def test_transaction_run_parameters( async def test_transaction_rollbacks_on_open_connections( async_fake_connection, ): - tx = AsyncTransaction(async_fake_connection, 2, None, noop, noop, noop) + tx = AsyncTransaction( + async_fake_connection, 2, None, noop, noop, noop, None + ) async with tx as tx_: async_fake_connection.is_reset_mock.return_value = False async_fake_connection.is_reset_mock.reset_mock() @@ -201,7 +201,9 @@ async def test_transaction_rollbacks_on_open_connections( async def test_transaction_no_rollback_on_reset_connections( async_fake_connection, ): - tx = AsyncTransaction(async_fake_connection, 2, None, noop, noop, noop) + tx = AsyncTransaction( + async_fake_connection, 2, None, noop, noop, noop, None + ) async with tx as tx_: async_fake_connection.is_reset_mock.return_value = True async_fake_connection.is_reset_mock.reset_mock() @@ -215,7 +217,9 @@ async def test_transaction_no_rollback_on_reset_connections( async def test_transaction_no_rollback_on_closed_connections( async_fake_connection, ): - tx = AsyncTransaction(async_fake_connection, 2, None, noop, noop, noop) + tx = AsyncTransaction( + async_fake_connection, 2, None, noop, noop, noop, None + ) async with tx as tx_: async_fake_connection.closed.return_value = True async_fake_connection.closed.reset_mock() @@ -231,7 +235,9 @@ async def test_transaction_no_rollback_on_closed_connections( async def test_transaction_no_rollback_on_defunct_connections( async_fake_connection, ): - tx = AsyncTransaction(async_fake_connection, 2, 
None, noop, noop, noop) + tx = AsyncTransaction( + async_fake_connection, 2, None, noop, noop, noop, None + ) async with tx as tx_: async_fake_connection.defunct.return_value = True async_fake_connection.defunct.reset_mock() @@ -246,7 +252,9 @@ async def test_transaction_no_rollback_on_defunct_connections( @pytest.mark.parametrize("pipeline", (True, False)) @mark_async_test async def test_transaction_begin_pipelining( - async_fake_connection, pipeline + async_fake_connection, + pipeline, + mocker, ) -> None: tx = AsyncTransaction( async_fake_connection, 2, None, noop, noop, noop, None @@ -285,6 +293,7 @@ async def test_transaction_begin_pipelining( "notifications_disabled_classifications": ( notifications_disabled_classifications ), + "on_success": mocker.ANY, }, ), ] @@ -335,7 +344,7 @@ async def test_server_error_propagates(async_scripted_connection, error): raise ValueError(f"Unknown error type {error}") connection.set_script(script) - tx = AsyncTransaction(connection, 2, None, noop, noop, noop) + tx = AsyncTransaction(connection, 2, None, noop, noop, noop, None) res1 = await tx.run("UNWIND range(1, 1000) AS n RETURN n") assert await res1.__anext__() == {"n": 1} @@ -351,3 +360,45 @@ async def test_server_error_propagates(async_scripted_connection, error): await res1.__anext__() assert exc1.value is exc2.value.__cause__ + + +@pytest.mark.parametrize("async_cb", (True, False)) +@pytest.mark.parametrize("resolved_db", (..., None, "resolved_db")) +@mark_async_test +async def test_on_database_callback( + async_scripted_connection, async_cb, resolved_db +): + cb_calls = [] + + if async_cb: + + async def db_callback(db): + nonlocal cb_calls + cb_calls.append(db) + else: + + def db_callback(db): + nonlocal cb_calls + cb_calls.append(db) + + begin_meta = {} + if resolved_db is not ...: + begin_meta["db"] = resolved_db + connection = async_scripted_connection + connection.set_script( + [ + ("begin", {"on_success": (begin_meta,), "on_summary": None}), + ] + ) + + result = AsyncTransaction( + connection, 1, None, noop, noop, noop, db_callback + ) + await result._begin( + None, None, None, None, None, None, None, None, pipelined=False + ) + + if resolved_db in {..., None}: + assert cb_calls == [] + else: + assert cb_calls == [resolved_db] diff --git a/tests/unit/sync/fixtures/fake_connection.py b/tests/unit/sync/fixtures/fake_connection.py index 8785badb..ca9a5c80 100644 --- a/tests/unit/sync/fixtures/fake_connection.py +++ b/tests/unit/sync/fixtures/fake_connection.py @@ -127,7 +127,14 @@ def callback(): return func method_mock = parent.__getattr__(name) - if name in {"run", "commit", "pull", "rollback", "discard"}: + if name in { + "begin", + "run", + "commit", + "pull", + "rollback", + "discard", + }: method_mock.side_effect = build_message_handler(name) return method_mock @@ -218,7 +225,14 @@ def callback(): return func method_mock = parent.__getattr__(name) - if name in {"run", "commit", "pull", "rollback", "discard"}: + if name in { + "begin", + "run", + "commit", + "pull", + "rollback", + "discard", + }: method_mock.side_effect = build_message_handler(name) return method_mock diff --git a/tests/unit/sync/fixtures/fake_pool.py b/tests/unit/sync/fixtures/fake_pool.py index 38d2ac4d..855d935a 100644 --- a/tests/unit/sync/fixtures/fake_pool.py +++ b/tests/unit/sync/fixtures/fake_pool.py @@ -17,6 +17,7 @@ import pytest from neo4j._sync.config import PoolConfig +from neo4j._sync.home_db_cache import HomeDbCache from neo4j._sync.io._pool import IOPool @@ -32,6 +33,9 @@ def 
fake_pool(fake_connection_generator, mocker): pool.buffered_connection_mocks = [] pool.acquired_connection_mocks = [] pool.pool_config = PoolConfig() + pool.ssr_enabled = False + pool.is_direct_pool = True + pool.home_db_cache = HomeDbCache(enabled=False) def acquire_side_effect(*_, **__): if pool.buffered_connection_mocks: diff --git a/tests/unit/sync/io/test_neo4j_pool.py b/tests/unit/sync/io/test_neo4j_pool.py index 89b4d16b..13b9be4e 100644 --- a/tests/unit/sync/io/test_neo4j_pool.py +++ b/tests/unit/sync/io/test_neo4j_pool.py @@ -30,6 +30,7 @@ from neo4j._deadline import Deadline from neo4j._sync.config import PoolConfig from neo4j._sync.io import ( + AcquisitionDatabase, Bolt, Neo4jPool, ) @@ -53,11 +54,27 @@ WRITER1_ADDRESS = ResolvedAddress(("1.2.3.1", 9020), host_name="host") +def make_home_db_resolve(home_db): + def _home_db_resolve(db): + return db or home_db + + return _home_db_resolve + + +_default_db_resolve = make_home_db_resolve("neo4j") + + @pytest.fixture def custom_routing_opener(fake_connection_generator, mocker): - def make_opener(failures=None, get_readers=None): + def make_opener( + failures=None, + get_readers=None, + db_resolve=_default_db_resolve, + on_open=None, + ): def routing_side_effect(*args, **kwargs): nonlocal failures + opener_.route_requests.append(kwargs.get("database")) res = next(failures, None) if res is None: routers = [ @@ -70,16 +87,18 @@ def routing_side_effect(*args, **kwargs): else: readers = [str(READER1_ADDRESS)] writers = [str(WRITER1_ADDRESS)] - return [ - { - "ttl": 1000, - "servers": [ - {"addresses": routers, "role": "ROUTE"}, - {"addresses": readers, "role": "READ"}, - {"addresses": writers, "role": "WRITE"}, - ], - } - ] + rt = { + "ttl": 1000, + "servers": [ + {"addresses": routers, "role": "ROUTE"}, + {"addresses": readers, "role": "READ"}, + {"addresses": writers, "role": "WRITE"}, + ], + } + db = db_resolve(kwargs.get("database")) + if db is not ...: + rt["db"] = db + return [rt] raise res def open_(addr, auth, timeout): @@ -92,11 +111,16 @@ def open_(addr, auth, timeout): route_mock.side_effect = routing_side_effect connection.attach_mock(route_mock, "route") opener_.connections.append(connection) + + if callable(on_open): + on_open(connection) + return connection failures = iter(failures or []) opener_ = mocker.MagicMock() opener_.connections = [] + opener_.route_requests = [] opener_.side_effect = open_ return opener_ @@ -124,54 +148,101 @@ def _simple_pool(opener) -> Neo4jPool: ) +TEST_DB1 = AcquisitionDatabase("test_db1") +TEST_DB2 = AcquisitionDatabase("test_db2") + + +@pytest.mark.parametrize("guessed_db", (True, False)) @mark_sync_test -def test_acquires_new_routing_table_if_deleted(opener): +def test_acquires_new_routing_table_if_deleted( + custom_routing_opener, + guessed_db, +) -> None: + db = AcquisitionDatabase("test_db", guessed=guessed_db) + opener = custom_routing_opener(db_resolve=make_home_db_resolve(db.name)) pool = _simple_pool(opener) - cx = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx = pool.acquire(READ_ACCESS, 30, db, None, None, None) pool.release(cx) - assert pool.routing_tables.get("test_db") + assert pool.routing_tables.get(db.name) + assert opener.route_requests == [None if guessed_db else db.name] + opener.route_requests = [] - del pool.routing_tables["test_db"] + del pool.routing_tables[db.name] - cx = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx = pool.acquire(READ_ACCESS, 30, db, None, None, None) pool.release(cx) - assert pool.routing_tables.get("test_db") + 
assert pool.routing_tables.get(db.name) + assert opener.route_requests == [None if guessed_db else db.name] +@pytest.mark.parametrize("guessed_db", (True, False)) @mark_sync_test -def test_acquires_new_routing_table_if_stale(opener): +def test_acquires_new_routing_table_if_stale( + custom_routing_opener, + guessed_db, +) -> None: + db = AcquisitionDatabase("test_db", guessed=guessed_db) + opener = custom_routing_opener(db_resolve=make_home_db_resolve(db.name)) pool = _simple_pool(opener) - cx = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx = pool.acquire(READ_ACCESS, 30, db, None, None, None) pool.release(cx) - assert pool.routing_tables.get("test_db") + assert pool.routing_tables.get(db.name) + assert opener.route_requests == [None if guessed_db else db.name] + opener.route_requests = [] - old_value = pool.routing_tables["test_db"].last_updated_time - pool.routing_tables["test_db"].ttl = 0 + old_value = pool.routing_tables[db.name].last_updated_time + pool.routing_tables[db.name].ttl = 0 - cx = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx = pool.acquire(READ_ACCESS, 30, db, None, None, None) pool.release(cx) - assert pool.routing_tables["test_db"].last_updated_time > old_value + assert pool.routing_tables[db.name].last_updated_time > old_value + assert opener.route_requests == [None if guessed_db else db.name] @mark_sync_test def test_removes_old_routing_table(opener): pool = _simple_pool(opener) - cx = pool.acquire(READ_ACCESS, 30, "test_db1", None, None, None) + cx = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) pool.release(cx) - assert pool.routing_tables.get("test_db1") - cx = pool.acquire(READ_ACCESS, 30, "test_db2", None, None, None) + assert pool.routing_tables.get(TEST_DB1.name) + cx = pool.acquire(READ_ACCESS, 30, TEST_DB2, None, None, None) pool.release(cx) - assert pool.routing_tables.get("test_db2") + assert pool.routing_tables.get(TEST_DB2.name) - old_value = pool.routing_tables["test_db1"].last_updated_time - pool.routing_tables["test_db1"].ttl = 0 - db2_rt = pool.routing_tables["test_db2"] + old_value = pool.routing_tables[TEST_DB1.name].last_updated_time + pool.routing_tables[TEST_DB1.name].ttl = 0 + db2_rt = pool.routing_tables[TEST_DB2.name] db2_rt.ttl = -RoutingConfig.routing_table_purge_delay - cx = pool.acquire(READ_ACCESS, 30, "test_db1", None, None, None) + cx = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) pool.release(cx) - assert pool.routing_tables["test_db1"].last_updated_time > old_value - assert "test_db2" not in pool.routing_tables + assert pool.routing_tables[TEST_DB1.name].last_updated_time > old_value + assert TEST_DB2.name not in pool.routing_tables + + +@pytest.mark.parametrize("guessed_db", (True, False)) +@mark_sync_test +def test_db_resolution_callback(custom_routing_opener, guessed_db): + cb_calls = [] + + def cb(db_): + nonlocal cb_calls + cb_calls.append(db_) + + db = AcquisitionDatabase("test_db", guessed=guessed_db) + home_db = "home_db" + expected_target_db = home_db if db.guessed else db.name + + opener = custom_routing_opener(db_resolve=make_home_db_resolve(home_db)) + pool = _simple_pool(opener) + cx = pool.acquire( + READ_ACCESS, 30, db, None, None, None, database_callback=cb + ) + pool.release(cx) + + assert pool.routing_tables.get(expected_target_db) + assert opener.route_requests == [None if guessed_db else db.name] + assert cb_calls == [expected_target_db] @pytest.mark.parametrize("type_", ("r", "w")) @@ -181,7 +252,7 @@ def test_chooses_right_connection_type(opener, type_): 
cx1 = pool.acquire( READ_ACCESS if type_ == "r" else WRITE_ACCESS, 30, - "test_db", + TEST_DB1, None, None, None, @@ -196,9 +267,9 @@ def test_chooses_right_connection_type(opener, type_): @mark_sync_test def test_reuses_connection(opener): pool = _simple_pool(opener) - cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) pool.release(cx1) - cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx2 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) assert cx1 is cx2 @@ -216,7 +287,7 @@ def break_connection(): return None pool = _simple_pool(opener) - cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) pool.release(cx1) assert cx1 in pool.connections[cx1.unresolved_address] # simulate connection going stale (e.g. exceeding idle timeout) and then @@ -226,7 +297,7 @@ def break_connection(): if break_on_close: cx_close_mock_side_effect = cx_close_mock.side_effect cx_close_mock.side_effect = break_connection - cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx2 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) pool.release(cx2) if break_on_close: cx1.close.assert_called() @@ -241,12 +312,12 @@ def break_connection(): @mark_sync_test def test_does_not_close_stale_connections_in_use(opener): pool = _simple_pool(opener) - cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) assert cx1 in pool.connections[cx1.unresolved_address] # simulate connection going stale (e.g. exceeding idle timeout) while being # in use cx1.stale.return_value = True - cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx2 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) pool.release(cx2) cx1.close.assert_not_called() assert cx2 is not cx1 @@ -259,7 +330,7 @@ def test_does_not_close_stale_connections_in_use(opener): # it should be closed when trying to acquire the next connection cx1.close.assert_not_called() - cx3 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx3 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) pool.release(cx3) cx1.close.assert_called_once() assert cx2 is cx3 @@ -271,7 +342,7 @@ def test_does_not_close_stale_connections_in_use(opener): @mark_sync_test def test_release_resets_connections(opener): pool = _simple_pool(opener) - cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) cx1.is_reset_mock.return_value = False cx1.is_reset_mock.reset_mock() pool.release(cx1) @@ -282,7 +353,7 @@ def test_release_resets_connections(opener): @mark_sync_test def test_release_does_not_resets_closed_connections(opener): pool = _simple_pool(opener) - cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) cx1.closed.return_value = True cx1.closed.reset_mock() cx1.is_reset_mock.reset_mock() @@ -295,7 +366,7 @@ def test_release_does_not_resets_closed_connections(opener): @mark_sync_test def test_release_does_not_resets_defunct_connections(opener): pool = _simple_pool(opener) - cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) cx1.defunct.return_value = True cx1.defunct.reset_mock() cx1.is_reset_mock.reset_mock() @@ -457,8 +528,8 @@ def close_side_effect(): # 
create pool with 2 idle connections pool = _simple_pool(opener) - cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) - cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) + cx2 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) pool.release(cx1) pool.release(cx2) @@ -470,7 +541,7 @@ def close_side_effect(): # unreachable cx1.stale.return_value = True - cx3 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx3 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) assert cx3 is not cx1 assert cx3 is not cx2 @@ -479,11 +550,11 @@ def close_side_effect(): @mark_sync_test def test_failing_opener_leaves_connections_in_use_alone(opener): pool = _simple_pool(opener) - cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) opener.side_effect = ServiceUnavailable("Server overloaded") with pytest.raises((ServiceUnavailable, SessionExpired)): - pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) assert not cx1.closed() @@ -505,7 +576,7 @@ def test__acquire_new_later_without_room(opener): config = _pool_config() config.max_connection_pool_size = 1 pool = Neo4jPool(opener, config, WorkspaceConfig(), ROUTER1_ADDRESS) - _ = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + _ = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) # pool is full now assert pool.connections_reservations[READER1_ADDRESS] == 0 creator = pool._acquire_new_later(READER1_ADDRESS, None, Deadline(1)) @@ -559,13 +630,13 @@ def test_discovery_is_retried(custom_routing_opener, error): WorkspaceConfig(), ResolvedAddress(("1.2.3.1", 9999), host_name="host"), ) - cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) pool.release(cx1) - pool.routing_tables.get("test_db").ttl = 0 + pool.routing_tables.get(TEST_DB1.name).ttl = 0 - cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx2 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) pool.release(cx2) - assert pool.routing_tables.get("test_db") + assert pool.routing_tables.get(TEST_DB1.name) assert cx1 is cx2 @@ -611,12 +682,12 @@ def test_fast_failing_discovery(custom_routing_opener, error): WorkspaceConfig(), ResolvedAddress(("1.2.3.1", 9999), host_name="host"), ) - cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) pool.release(cx1) - pool.routing_tables.get("test_db").ttl = 0 + pool.routing_tables.get(TEST_DB1.name).ttl = 0 with pytest.raises(error.__class__) as exc: - pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) assert exc.value is error @@ -657,11 +728,11 @@ def test_connection_error_callback( config.auth = auth_manager pool = Neo4jPool(opener, config, WorkspaceConfig(), ROUTER1_ADDRESS) cxs_read = [ - pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) for _ in range(5) ] cxs_write = [ - pool.acquire(WRITE_ACCESS, 30, "test_db", None, None, None) + pool.acquire(WRITE_ACCESS, 30, TEST_DB1, None, None, None) for _ in range(5) ] @@ -690,7 +761,7 @@ def test_connection_error_callback( @mark_sync_test def test_pool_closes_connections_dropped_from_rt(custom_routing_opener): - readers = {"db1": 
[str(READER1_ADDRESS)]} + readers = {TEST_DB1.name: [str(READER1_ADDRESS)]} def get_readers(database): return readers[database] @@ -700,7 +771,7 @@ def get_readers(database): pool = Neo4jPool( opener, _pool_config(), WorkspaceConfig(), ROUTER1_ADDRESS ) - cx1 = pool.acquire(READ_ACCESS, 30, "db1", None, None, None) + cx1 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) assert cx1.unresolved_address == READER1_ADDRESS pool.release(cx1) @@ -708,10 +779,10 @@ def get_readers(database): assert len(pool.connections[READER1_ADDRESS]) == 1 # force RT refresh, returning a different reader - del pool.routing_tables["db1"] - readers["db1"] = [str(READER2_ADDRESS)] + del pool.routing_tables[TEST_DB1.name] + readers[TEST_DB1.name] = [str(READER2_ADDRESS)] - cx2 = pool.acquire(READ_ACCESS, 30, "db1", None, None, None) + cx2 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) assert cx2.unresolved_address == READER2_ADDRESS cx1.close.assert_called_once() @@ -726,8 +797,8 @@ def test_pool_does_not_close_connections_dropped_from_rt_for_other_server( # no custom_routing_opener, ): readers = { - "db1": [str(READER1_ADDRESS), str(READER2_ADDRESS)], - "db2": [str(READER1_ADDRESS)], + TEST_DB1.name: [str(READER1_ADDRESS), str(READER2_ADDRESS)], + TEST_DB2.name: [str(READER1_ADDRESS)], } def get_readers(database): @@ -738,14 +809,14 @@ def get_readers(database): pool = Neo4jPool( opener, _pool_config(), WorkspaceConfig(), ROUTER1_ADDRESS ) - cx1 = pool.acquire(READ_ACCESS, 30, "db1", None, None, None) + cx1 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) pool.release(cx1) assert cx1.unresolved_address in {READER1_ADDRESS, READER2_ADDRESS} reader1_connection_count = len(pool.connections[READER1_ADDRESS]) reader2_connection_count = len(pool.connections[READER2_ADDRESS]) assert reader1_connection_count + reader2_connection_count == 1 - cx2 = pool.acquire(READ_ACCESS, 30, "db2", None, None, None) + cx2 = pool.acquire(READ_ACCESS, 30, TEST_DB2, None, None, None) pool.release(cx2) assert cx2.unresolved_address == READER1_ADDRESS cx1.close.assert_not_called() @@ -754,10 +825,10 @@ def get_readers(database): assert len(pool.connections[READER2_ADDRESS]) == reader2_connection_count # force RT refresh, returning a different reader - del pool.routing_tables["db2"] - readers["db2"] = [str(READER3_ADDRESS)] + del pool.routing_tables[TEST_DB2.name] + readers[TEST_DB2.name] = [str(READER3_ADDRESS)] - cx3 = pool.acquire(READ_ACCESS, 30, "db2", None, None, None) + cx3 = pool.acquire(READ_ACCESS, 30, TEST_DB2, None, None, None) pool.release(cx3) assert cx3.unresolved_address == READER3_ADDRESS @@ -767,3 +838,79 @@ def get_readers(database): assert len(pool.connections[READER1_ADDRESS]) == 1 assert len(pool.connections[READER2_ADDRESS]) == reader2_connection_count assert len(pool.connections[READER3_ADDRESS]) == 1 + + +@mark_sync_test +def test_tracks_ssr_connection_hints(custom_routing_opener): + connection_count = 0 + + def on_open(connection): + if connection.unresolved_address in { + ROUTER1_ADDRESS, + ROUTER2_ADDRESS, + ROUTER3_ADDRESS, + }: + connection.ssr_enabled = True + return + nonlocal connection_count + connection_count += 1 + connection.ssr_enabled = connection_count != 2 + + opener = custom_routing_opener(on_open=on_open) + pool = Neo4jPool( + opener, _pool_config(), WorkspaceConfig(), ROUTER1_ADDRESS + ) + + # no connection in pool => cannot know => defensive assumption: off + assert not pool.ssr_enabled + + # open 1st reader connection (supports SSR) + cx1 = 
pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) + assert cx1.ssr_enabled # double check we got the mocking right + + assert pool.ssr_enabled + + # open 2nd reader connection (does not support SSR) + cx2 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) + assert not cx2.ssr_enabled # double check we got the mocking right + + assert not pool.ssr_enabled + + # open 3rd reader connection (supports SSR) + cx3 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) + assert cx3.ssr_enabled # double check we got the mocking right + + assert not pool.ssr_enabled + + pool.release(cx1) + pool.release(cx2) + pool.release(cx3) + + assert not pool.ssr_enabled + + cxs = [ + pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) + for _ in range(3) + ] + assert sum(not c.ssr_enabled for c in cxs) == 1 # double check + + for cx in (cx for cx in cxs if not cx.ssr_enabled): + cx.close() + + # after the single connection without SSR support is closed + for cx in cxs: + pool.release(cx) + + # force pool cleaning up all stale connections: + cxs = [ + pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) + for _ in range(3) + ] + assert all(cx.ssr_enabled for cx in cxs) # double check + + assert pool.ssr_enabled + + for cx in cxs: + pool.release(cx) + + assert pool.ssr_enabled diff --git a/tests/unit/sync/test_driver.py b/tests/unit/sync/test_driver.py index a508c108..456a3250 100644 --- a/tests/unit/sync/test_driver.py +++ b/tests/unit/sync/test_driver.py @@ -276,6 +276,7 @@ def test_driver_opens_write_session_by_default( bookmarks=mocker.ANY, auth=mocker.ANY, liveness_check_timeout=mocker.ANY, + database_callback=mocker.ANY, ) tx._begin.assert_called_once_with( mocker.ANY, diff --git a/tests/unit/sync/work/test_result.py b/tests/unit/sync/work/test_result.py index 6bb71e8b..4f87da73 100644 --- a/tests/unit/sync/work/test_result.py +++ b/tests/unit/sync/work/test_result.py @@ -315,7 +315,7 @@ def fetch_and_compare_all_records( @mark_sync_test def test_result_iteration(method, records): connection = ConnectionStub(records=Records(["x"], records)) - result = Result(connection, 2, None, noop, noop) + result = Result(connection, 2, None, noop, noop, None) result._run("CYPHER", {}, None, None, "r", None, None, None) fetch_and_compare_all_records(result, "x", records, method) @@ -324,7 +324,7 @@ def test_result_iteration(method, records): def test_result_iteration_mixed_methods(): records = [[i] for i in range(10)] connection = ConnectionStub(records=Records(["x"], records)) - result = Result(connection, 4, None, noop, noop) + result = Result(connection, 4, None, noop, noop, None) result._run("CYPHER", {}, None, None, "r", None, None, None) iter1 = Util.iter(result) iter2 = Util.iter(result) @@ -372,9 +372,9 @@ def test_parallel_result_iteration(method, invert_fetch): connection = ConnectionStub( records=(Records(["x"], records1), Records(["x"], records2)) ) - result1 = Result(connection, 2, None, noop, noop) + result1 = Result(connection, 2, None, noop, noop, None) result1._run("CYPHER1", {}, None, None, "r", None, None, None) - result2 = Result(connection, 2, None, noop, noop) + result2 = Result(connection, 2, None, noop, noop, None) result2._run("CYPHER2", {}, None, None, "r", None, None, None) if invert_fetch: fetch_and_compare_all_records(result2, "x", records2, method) @@ -395,9 +395,9 @@ def test_interwoven_result_iteration(method, invert_fetch): connection = ConnectionStub( records=(Records(["x"], records1), Records(["y"], records2)) ) - result1 = Result(connection, 2, None, 
noop, noop) + result1 = Result(connection, 2, None, noop, noop, None) result1._run("CYPHER1", {}, None, None, "r", None, None, None) - result2 = Result(connection, 2, None, noop, noop) + result2 = Result(connection, 2, None, noop, noop, None) result2._run("CYPHER2", {}, None, None, "r", None, None, None) start = 0 for n in (1, 2, 3, 1, None): @@ -424,7 +424,7 @@ def test_interwoven_result_iteration(method, invert_fetch): @mark_sync_test def test_result_peek(records, fetch_size): connection = ConnectionStub(records=Records(["x"], records)) - result = Result(connection, fetch_size, None, noop, noop) + result = Result(connection, fetch_size, None, noop, noop, None) result._run("CYPHER", {}, None, None, "r", None, None, None) for i in range(len(records) + 1): record = result.peek() @@ -447,7 +447,7 @@ def test_result_single_non_strict(records, fetch_size, default): kwargs["strict"] = False connection = ConnectionStub(records=Records(["x"], records)) - result = Result(connection, fetch_size, None, noop, noop) + result = Result(connection, fetch_size, None, noop, noop, None) result._run("CYPHER", {}, None, None, "r", None, None, None) if len(records) == 0: assert result.single(**kwargs) is None @@ -466,7 +466,7 @@ def test_result_single_non_strict(records, fetch_size, default): @mark_sync_test def test_result_single_strict(records, fetch_size): connection = ConnectionStub(records=Records(["x"], records)) - result = Result(connection, fetch_size, None, noop, noop) + result = Result(connection, fetch_size, None, noop, noop, None) result._run("CYPHER", {}, None, None, "r", None, None, None) if len(records) != 1: with pytest.raises(ResultNotSingleError) as exc: @@ -490,7 +490,7 @@ def test_result_single_strict(records, fetch_size): @mark_sync_test def test_result_single_exhausts_records(records, fetch_size, strict): connection = ConnectionStub(records=Records(["x"], records)) - result = Result(connection, fetch_size, None, noop, noop) + result = Result(connection, fetch_size, None, noop, noop, None) result._run("CYPHER", {}, None, None, "r", None, None, None) try: with warnings.catch_warnings(): @@ -512,7 +512,7 @@ def test_result_single_exhausts_records(records, fetch_size, strict): @mark_sync_test def test_result_fetch(records, fetch_size, strict): connection = ConnectionStub(records=Records(["x"], records)) - result = Result(connection, fetch_size, None, noop, noop) + result = Result(connection, fetch_size, None, noop, noop, None) result._run("CYPHER", {}, None, None, "r", None, None, None) assert result.fetch(0) == [] assert result.fetch(-1) == [] @@ -524,7 +524,7 @@ def test_result_fetch(records, fetch_size, strict): @mark_sync_test def test_keys_are_available_before_and_after_stream(): connection = ConnectionStub(records=Records(["x"], [[1], [2]])) - result = Result(connection, 1, None, noop, noop) + result = Result(connection, 1, None, noop, noop, None) result._run("CYPHER", {}, None, None, "r", None, None, None) assert list(result.keys()) == ["x"] Util.list(result) @@ -540,7 +540,7 @@ def test_consume(records, consume_one, summary_meta, consume_times): connection = ConnectionStub( records=Records(["x"], records), summary_meta=summary_meta ) - result = Result(connection, 1, None, noop, noop) + result = Result(connection, 1, None, noop, noop, None) result._run("CYPHER", {}, None, None, "r", None, None, None) if consume_one: with suppress(StopIteration): @@ -574,7 +574,7 @@ def test_time_in_summary(t_first, t_last): summary_meta=summary_meta, ) - result = Result(connection, 1, None, noop, noop) 
+ result = Result(connection, 1, None, noop, noop, None) result._run("CYPHER", {}, None, None, "r", None, None, None) summary = result.consume() @@ -596,7 +596,7 @@ def test_time_in_summary(t_first, t_last): def test_counts_in_summary(): connection = ConnectionStub(records=Records(["n"], [[1], [2]])) - result = Result(connection, 1, None, noop, noop) + result = Result(connection, 1, None, noop, noop, None) result._run("CYPHER", {}, None, None, "r", None, None, None) summary = result.consume() @@ -610,7 +610,7 @@ def test_query_type(query_type): records=Records(["n"], [[1], [2]]), summary_meta={"type": query_type} ) - result = Result(connection, 1, None, noop, noop) + result = Result(connection, 1, None, noop, noop, None) result._run("CYPHER", {}, None, None, "r", None, None, None) summary = result.consume() @@ -625,7 +625,7 @@ def test_data(num_records): records=Records(["n"], [[i + 1] for i in range(num_records)]) ) - result = Result(connection, 1, None, noop, noop) + result = Result(connection, 1, None, noop, noop, None) result._run("CYPHER", {}, None, None, "r", None, None, None) result._buffer_all() records = result._record_buffer.copy() @@ -667,7 +667,7 @@ def test_data(num_records): @mark_sync_test def test_result_graph(records): connection = ConnectionStub(records=records) - result = Result(connection, 1, None, noop, noop) + result = Result(connection, 1, None, noop, noop, None) result._run("CYPHER", {}, None, None, "r", None, None, None) graph = result.graph() assert isinstance(graph, Graph) @@ -760,7 +760,7 @@ def test_result_graph(records): def test_to_eager_result(records): summary = {"test_to_eager_result": uuid.uuid4()} connection = ConnectionStub(records=records, summary_meta=summary) - result = Result(connection, 1, None, noop, noop) + result = Result(connection, 1, None, noop, noop, None) result._run("CYPHER", {}, None, None, "r", None, None, None) eager_result = result.to_eager_result() @@ -850,7 +850,7 @@ def test_to_eager_result(records): @mark_sync_test def test_to_df(keys, values, types, instances, test_default_expand): connection = ConnectionStub(records=Records(keys, values)) - result = Result(connection, 1, None, noop, noop) + result = Result(connection, 1, None, noop, noop, None) result._run("CYPHER", {}, None, None, "r", None, None, None) if test_default_expand: df = result.to_df() @@ -1061,7 +1061,7 @@ def test_to_df_expand( keys, values, expected_columns, expected_rows, expected_types ): connection = ConnectionStub(records=Records(keys, values)) - result = Result(connection, 1, None, noop, noop) + result = Result(connection, 1, None, noop, noop, None) result._run("CYPHER", {}, None, None, "r", None, None, None) df = result.to_df(expand=True) @@ -1299,7 +1299,7 @@ def test_to_df_expand( @mark_sync_test def test_to_df_parse_dates(keys, values, expected_df, expand): connection = ConnectionStub(records=Records(keys, values)) - result = Result(connection, 1, None, noop, noop) + result = Result(connection, 1, None, noop, noop, None) result._run("CYPHER", {}, None, None, "r", None, None, None) df = result.to_df(expand=expand, parse_dates=True) @@ -1314,7 +1314,7 @@ def test_broken_hydration(nested): value_in = [value_in] records_in = Records(["foo", "bar"], [["foobar", value_in]]) connection = ConnectionStub(records=records_in) - result = Result(connection, 1, None, noop, noop) + result = Result(connection, 1, None, noop, noop, None) result._run("CYPHER", {}, None, None, "r", None, None, None) records_out = Util.list(result) assert len(records_out) == 1 @@ -1422,3 
+1422,37 @@ def test_notification_logging( f"Received notification from DBMS server: {formatted_notification}" ) assert caplog.messages[0] == expected_message + + +@pytest.mark.parametrize("cb", (True, False)) +@pytest.mark.parametrize("resolved_db", (..., None, "resolved_db")) +@mark_sync_test +def test_on_database_callback(cb, resolved_db): + cb_calls = [] + + if cb: + + def db_callback(db): + nonlocal cb_calls + cb_calls.append(db) + + else: + + def db_callback(db): + nonlocal cb_calls + cb_calls.append(db) + + run_meta = {} + if resolved_db is not ...: + run_meta["db"] = resolved_db + connection = ConnectionStub( + records=Records(["foo"], ()), run_meta=run_meta + ) + + result = Result(connection, 1, None, noop, noop, db_callback) + result._run("CYPHER", {}, None, None, "r", None, None, None) + + if resolved_db in {..., None}: + assert cb_calls == [] + else: + assert cb_calls == [resolved_db] diff --git a/tests/unit/sync/work/test_session.py b/tests/unit/sync/work/test_session.py index 52843be1..c081bb27 100644 --- a/tests/unit/sync/work/test_session.py +++ b/tests/unit/sync/work/test_session.py @@ -19,6 +19,7 @@ import pytest from neo4j import ( + Auth, Bookmarks, ManagedTransaction, Session, @@ -26,8 +27,12 @@ unit_of_work, ) from neo4j._api import TelemetryAPI +from neo4j._async_compat.util import Util +from neo4j._auth_management import to_auth_dict from neo4j._conf import SessionConfig +from neo4j._sync.home_db_cache import HomeDbCache from neo4j._sync.io import ( + AcquisitionDatabase, BoltPool, Neo4jPool, ) @@ -430,12 +435,12 @@ def work(tx): assert call.kwargs["parameters"] == expected_params -@pytest.mark.parametrize("db", (None, "adb")) -@pytest.mark.parametrize("routing", (True, False)) +@pytest.mark.parametrize("db", (None, "adb")[:1]) +@pytest.mark.parametrize("routing", (True, False)[:1]) # no home db resolution when connected to Neo4j 4.3 or earlier -@pytest.mark.parametrize("home_db_gets_resolved", (True, False)) +@pytest.mark.parametrize("home_db_gets_resolved", (True, False)[:1]) @pytest.mark.parametrize( - "additional_session_bookmarks", (None, ["session", "bookmarks"]) + "additional_session_bookmarks", (None, ["session", "bookmarks"])[:1] ) @mark_sync_test def test_with_bookmark_manager( @@ -490,8 +495,10 @@ def bmm_get_bookmarks(): fake_pool.update_routing_table.side_effect = ( update_routing_table_side_effect ) + fake_pool.is_direct_pool = False else: fake_pool.mock_add_spec(BoltPool) + fake_pool.is_direct_pool = True config = SessionConfig() config.bookmark_manager = bmm @@ -699,3 +706,171 @@ def work(_): connection_mock.telemetry.assert_called_once() call_args = connection_mock.telemetry.call_args.args assert call_args[0] == TelemetryAPI.DRIVER + + +@pytest.mark.parametrize( + ("db", "pool_ssr", "pool_routing", "expect_cache_usage"), + ( + (db, ssr, routing, ssr and routing and not db) + for ssr in (True, False) + for routing in (True, False) + for db in (None, "mydb") + ), +) +@pytest.mark.parametrize("imp_user", (None, "imp_user")) +@pytest.mark.parametrize( + "auth", + (None, Auth(scheme="magic-auth", principal=None, credentials="tada")), +) +@mark_sync_test +def test_uses_home_db_cache_when_expected( + fake_pool, + mocker, + db, + pool_ssr, + pool_routing, + expect_cache_usage, + imp_user, + auth, +): + fake_pool.ssr_enabled = pool_ssr + if pool_routing: + fake_pool.is_direct_pool = False + fake_pool.mock_add_spec(Neo4jPool) + cache_spy = mocker.Mock(spec=HomeDbCache, wraps=HomeDbCache()) + cached_db = "nice_cached_home_db" + key = object() + 
cache_spy.compute_key.return_value = key + cache_spy.get.return_value = cached_db + fake_pool.home_db_cache = cache_spy + + config = SessionConfig() + config.impersonated_user = imp_user + config.auth = auth + config.database = db + + with Session(fake_pool, config) as session: + session.run("RETURN 1") + + if expect_cache_usage: + # assert using cache + assert cache_spy.mock_calls == [ + mocker.call.compute_key( + imp_user, to_auth_dict(auth) if auth else None + ), + mocker.call.get(key), + ] + # assert passing cache result as a guess to the pool + fake_pool.acquire.assert_called_once_with( + access_mode=mocker.ANY, + timeout=mocker.ANY, + database=AcquisitionDatabase(cached_db, guessed=True), + bookmarks=mocker.ANY, + auth=mocker.ANY, + liveness_check_timeout=mocker.ANY, + database_callback=mocker.ANY, + ) + else: + # assert not using cache + cache_spy.get.assert_not_called() + # assert passing a non-guess to the pool + fake_pool.acquire.assert_called_once_with( + access_mode=mocker.ANY, + timeout=mocker.ANY, + database=AcquisitionDatabase(db, guessed=False), + bookmarks=mocker.ANY, + auth=mocker.ANY, + liveness_check_timeout=mocker.ANY, + database_callback=mocker.ANY, + ) + + +@pytest.mark.parametrize( + ("db", "pool_ssr", "pool_routing", "expect_cache_usage"), + ( + (db, ssr, routing, ssr and routing and not db) + for ssr in (True, False) + for routing in (True, False) + for db in (None, "mydb") + ), +) +@pytest.mark.parametrize("resolution_at", ("route", "run", "begin")) +@mark_sync_test +def test_pinns_session_db_with_cache( + fake_pool, + mocker, + db, + pool_ssr, + pool_routing, + expect_cache_usage, + resolution_at, +): + def resolve_db(): + if resolution_at == "route": + database_callback = fake_pool.acquire.call_args.kwargs[ + "database_callback" + ] + Util.callback(database_callback, resolved_db) + elif resolution_at == "run": + database_callback = res_mock.call_args.args[-1] + Util.callback(database_callback, resolved_db) + elif resolution_at == "begin": + database_callback = tx_mock.call_args.args[-1] + Util.callback(database_callback, resolved_db) + else: + raise ValueError(f"Unknown resolution_at: {resolution_at}") + + if resolution_at == "run": + res_mock = mocker.patch( + "neo4j._sync.work.session.Result", autospec=True + ) + elif resolution_at == "begin": + tx_mock = mocker.patch( + "neo4j._sync.work.session.Transaction", autospec=True + ) + + resolved_db = "resolved_db" + fake_pool.ssr_enabled = pool_ssr + if pool_routing: + fake_pool.is_direct_pool = False + fake_pool.mock_add_spec(Neo4jPool) + cache_spy = mocker.Mock(spec=HomeDbCache, wraps=HomeDbCache()) + key = object() + cache_spy.compute_key.return_value = key + fake_pool.home_db_cache = cache_spy + + config = SessionConfig() + config.database = db + + with Session(fake_pool, config) as session: + if resolution_at == "begin": + with session.begin_transaction() as tx: + tx.run("RETURN 1") + else: + session.run("RETURN 1") + + if expect_cache_usage: + # assert never using cache to pin a database + assert not session._pinned_database + assert config.database == db + + resolve_db() + + assert session._pinned_database + assert config.database == resolved_db + cache_spy.set.assert_called_once_with(key, resolved_db) + else: + if not pool_routing or db: + assert session._pinned_database + assert config.database == db + + resolve_db() + + if not pool_routing or db: + assert session._pinned_database + assert config.database == db + cache_spy.set.assert_not_called() + else: + cache_spy.set.assert_called_once_with(key, 
resolved_db) + assert session._pinned_database + assert config.database == resolved_db diff --git a/tests/unit/sync/work/test_transaction.py b/tests/unit/sync/work/test_transaction.py index 683ff45d..b78768a2 100644 --- a/tests/unit/sync/work/test_transaction.py +++ b/tests/unit/sync/work/test_transaction.py @@ -12,8 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - from unittest.mock import MagicMock import pytest @@ -52,7 +50,7 @@ def test_transaction_context_when_committing( on_error = mocker.MagicMock() on_cancel = mocker.Mock() tx = Transaction( - fake_connection, 2, None, on_closed, on_error, on_cancel + fake_connection, 2, None, on_closed, on_error, on_cancel, None ) mock_commit = mocker.patch.object(tx, "_commit", wraps=tx._commit) mock_rollback = mocker.patch.object(tx, "_rollback", wraps=tx._rollback) @@ -88,7 +86,7 @@ def test_transaction_context_with_explicit_rollback( on_error = mocker.MagicMock() on_cancel = mocker.Mock() tx = Transaction( - fake_connection, 2, None, on_closed, on_error, on_cancel + fake_connection, 2, None, on_closed, on_error, on_cancel, None ) mock_commit = mocker.patch.object(tx, "_commit", wraps=tx._commit) mock_rollback = mocker.patch.object(tx, "_rollback", wraps=tx._rollback) @@ -120,7 +118,7 @@ class OopsError(RuntimeError): on_error = MagicMock() on_cancel = MagicMock() tx = Transaction( - fake_connection, 2, None, on_closed, on_error, on_cancel + fake_connection, 2, None, on_closed, on_error, on_cancel, None ) mock_commit = mocker.patch.object(tx, "_commit", wraps=tx._commit) mock_rollback = mocker.patch.object(tx, "_rollback", wraps=tx._rollback) @@ -141,7 +139,7 @@ def test_transaction_run_takes_no_query_object(fake_connection): on_error = MagicMock() on_cancel = MagicMock() tx = Transaction( - fake_connection, 2, None, on_closed, on_error, on_cancel + fake_connection, 2, None, on_closed, on_error, on_cancel, None ) with pytest.raises(ValueError): tx.run(Query("RETURN 1")) @@ -165,7 +163,7 @@ def test_transaction_run_parameters( on_error = MagicMock() on_cancel = MagicMock() tx = Transaction( - fake_connection, 2, None, on_closed, on_error, on_cancel + fake_connection, 2, None, on_closed, on_error, on_cancel, None ) if not as_kwargs: params = {"parameters": params} @@ -187,7 +185,9 @@ def test_transaction_run_parameters( def test_transaction_rollbacks_on_open_connections( fake_connection, ): - tx = Transaction(fake_connection, 2, None, noop, noop, noop) + tx = Transaction( + fake_connection, 2, None, noop, noop, noop, None + ) with tx as tx_: fake_connection.is_reset_mock.return_value = False fake_connection.is_reset_mock.reset_mock() @@ -201,7 +201,9 @@ def test_transaction_rollbacks_on_open_connections( def test_transaction_no_rollback_on_reset_connections( fake_connection, ): - tx = Transaction(fake_connection, 2, None, noop, noop, noop) + tx = Transaction( + fake_connection, 2, None, noop, noop, noop, None + ) with tx as tx_: fake_connection.is_reset_mock.return_value = True fake_connection.is_reset_mock.reset_mock() @@ -215,7 +217,9 @@ def test_transaction_no_rollback_on_reset_connections( def test_transaction_no_rollback_on_closed_connections( fake_connection, ): - tx = Transaction(fake_connection, 2, None, noop, noop, noop) + tx = Transaction( + fake_connection, 2, None, noop, noop, noop, None + ) with tx as tx_: fake_connection.closed.return_value = True fake_connection.closed.reset_mock() @@ -231,7 
+235,9 @@ def test_transaction_no_rollback_on_closed_connections( def test_transaction_no_rollback_on_defunct_connections( fake_connection, ): - tx = Transaction(fake_connection, 2, None, noop, noop, noop) + tx = Transaction( + fake_connection, 2, None, noop, noop, noop, None + ) with tx as tx_: fake_connection.defunct.return_value = True fake_connection.defunct.reset_mock() @@ -246,7 +252,9 @@ def test_transaction_no_rollback_on_defunct_connections( @pytest.mark.parametrize("pipeline", (True, False)) @mark_sync_test def test_transaction_begin_pipelining( - fake_connection, pipeline + fake_connection, + pipeline, + mocker, ) -> None: tx = Transaction( fake_connection, 2, None, noop, noop, noop, None @@ -285,6 +293,7 @@ def test_transaction_begin_pipelining( "notifications_disabled_classifications": ( notifications_disabled_classifications ), + "on_success": mocker.ANY, }, ), ] @@ -335,7 +344,7 @@ def test_server_error_propagates(scripted_connection, error): raise ValueError(f"Unknown error type {error}") connection.set_script(script) - tx = Transaction(connection, 2, None, noop, noop, noop) + tx = Transaction(connection, 2, None, noop, noop, noop, None) res1 = tx.run("UNWIND range(1, 1000) AS n RETURN n") assert res1.__next__() == {"n": 1} @@ -351,3 +360,45 @@ def test_server_error_propagates(scripted_connection, error): res1.__next__() assert exc1.value is exc2.value.__cause__ + + +@pytest.mark.parametrize("cb", (True, False)) +@pytest.mark.parametrize("resolved_db", (..., None, "resolved_db")) +@mark_sync_test +def test_on_database_callback( + scripted_connection, cb, resolved_db +): + cb_calls = [] + + if cb: + + def db_callback(db): + nonlocal cb_calls + cb_calls.append(db) + else: + + def db_callback(db): + nonlocal cb_calls + cb_calls.append(db) + + begin_meta = {} + if resolved_db is not ...: + begin_meta["db"] = resolved_db + connection = scripted_connection + connection.set_script( + [ + ("begin", {"on_success": (begin_meta,), "on_summary": None}), + ] + ) + + result = Transaction( + connection, 1, None, noop, noop, noop, db_callback + ) + result._begin( + None, None, None, None, None, None, None, None, pipelined=False + ) + + if resolved_db in {..., None}: + assert cb_calls == [] + else: + assert cb_calls == [resolved_db] From d3d34901e5c6f74bf387acbb121be0946d4024ba Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Thu, 14 Nov 2024 15:35:11 +0100 Subject: [PATCH 11/26] Docs: remove left-overs from first home db cache spike --- docs/source/api.rst | 41 +++++++------------------------------- docs/source/async_api.rst | 3 --- src/neo4j/_async/driver.py | 25 ----------------------- src/neo4j/_sync/driver.py | 25 ----------------------- 4 files changed, 7 insertions(+), 87 deletions(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index ff6ed1d5..1220629d 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -262,9 +262,6 @@ Closing a driver will immediately shut down all connections in the pool. :data:`None` (default) uses the database configured on the server side. - Depending on the :ref:`max-home-database-delay-ref` configuration, - propagation of changes to the server side default might not be - immediate. .. 
Note:: It is recommended to always specify the database explicitly @@ -403,7 +400,6 @@ Additional configuration can be provided via the :class:`neo4j.Driver` construct + :ref:`liveness-check-timeout-ref` + :ref:`max-connection-pool-size-ref` + :ref:`max-transaction-retry-time-ref` -+ :ref:`max-home-database-delay-ref` + :ref:`resolver-ref` + :ref:`trust-ref` + :ref:`ssl-context-ref` @@ -526,26 +522,6 @@ The maximum total number of connections allowed, per host (i.e. cluster nodes), :Default: ``30.0`` -.. _max-home-database-delay-ref: - -``max_home_database_delay`` ---------------------------- -Defines an upper bound for how long (in seconds) a resolved home database can be cached. - -Set this value to ``0`` to prohibit any caching. -This likely incurs a significant performance penalty (driver and server side). -Set this value to ``float("inf")`` to allow the driver to cache resolutions forever. - -Note that in future driver/protocol versions, this setting might have no effect. - -:Type: ``float`` -:Default: ``5.0`` - -.. versionadded:: 5.x - -.. seealso:: :meth:`Driver.force_home_database_resolution` - - .. _resolver-ref: ``resolver`` @@ -1057,16 +1033,13 @@ Specifically, the following applies: instance, if the user's home database name is 'movies' and the server supplies it to the driver upon database name fetching for the session, all queries within that session are executed with the explicit database - name 'movies' supplied. Changes to the user's home database will only be - picked up by future sessions. There might be an additional delay depending - on the :ref:`max-home-database-delay-ref` configuration. Resolving the - user's home database name requires additional network communication. - Therefore, it is either recommended to either specify the database name - explicitly or set the home database delay appropriately. - In clustered environments, it is strongly recommended to avoid a single - point of failure. For instance, by ensuring that the connection URI - resolves to multiple endpoints. For older Bolt protocol versions the - behavior is the same as described for the **bolt schemes** above. + name 'movies' supplied. Any change to the user’s home database is + reflected only in sessions created after such change takes effect. This + behavior may requires additional network communication. In clustered + environments, it is strongly recommended to avoid a single point of + failure. For instance, by ensuring that the connection URI resolves to + multiple endpoints. For older Bolt protocol versions the behavior is the + same as described for the **bolt schemes** above. .. code-block:: python diff --git a/docs/source/async_api.rst b/docs/source/async_api.rst index ccbc6d75..6c6e62e4 100644 --- a/docs/source/async_api.rst +++ b/docs/source/async_api.rst @@ -249,9 +249,6 @@ Closing a driver will immediately shut down all connections in the pool. :data:`None` (default) uses the database configured on the server side. - Depending on the :ref:`max-home-database-delay-ref` configuration, - propagation of changes to the server side default might not be - immediate. .. Note:: It is recommended to always specify the database explicitly diff --git a/src/neo4j/_async/driver.py b/src/neo4j/_async/driver.py index dce77d35..a5420318 100644 --- a/src/neo4j/_async/driver.py +++ b/src/neo4j/_async/driver.py @@ -801,9 +801,6 @@ async def example(driver: neo4j.AsyncDriver) -> int: :data:`None` (default) uses the database configured on the server side. 
- Depending on the :ref:`max-home-database-delay-ref` configuration, - propagation of changes to the server side default might not be - immediate. .. Note:: It is recommended to always specify the database explicitly @@ -1300,28 +1297,6 @@ async def _get_server_info(self, session_config) -> ServerInfo: async with self._session(session_config) as session: return await session._get_server_info() - def force_home_database_resolution(self) -> None: - """Force the driver to resolve all home databases (again). - - The resolution is lazy and will only happen when the driver needs to - know the home database. - In practice, this means that the driver will flush the cache - configured by `max_home_database_delay`. - - This method is for instance useful when an application has changed a - user's home database, and the same application wants to pick up the - change in the next session while wanting to avoid setting - `max_home_database_delay` to `0` because of the performance penalty. - - .. versionadded:: 5.x - - .. seealso:: - Driver config :ref:`max-home-database-delay-ref` - """ - home_db_cache = self._pool.home_db_cache - if home_db_cache.enabled: - home_db_cache.clear() - async def _work( tx: AsyncManagedTransaction, diff --git a/src/neo4j/_sync/driver.py b/src/neo4j/_sync/driver.py index 9d895710..971f9be7 100644 --- a/src/neo4j/_sync/driver.py +++ b/src/neo4j/_sync/driver.py @@ -800,9 +800,6 @@ def example(driver: neo4j.Driver) -> int: :data:`None` (default) uses the database configured on the server side. - Depending on the :ref:`max-home-database-delay-ref` configuration, - propagation of changes to the server side default might not be - immediate. .. Note:: It is recommended to always specify the database explicitly @@ -1299,28 +1296,6 @@ def _get_server_info(self, session_config) -> ServerInfo: with self._session(session_config) as session: return session._get_server_info() - def force_home_database_resolution(self) -> None: - """Force the driver to resolve all home databases (again). - - The resolution is lazy and will only happen when the driver needs to - know the home database. - In practice, this means that the driver will flush the cache - configured by `max_home_database_delay`. - - This method is for instance useful when an application has changed a - user's home database, and the same application wants to pick up the - change in the next session while wanting to avoid setting - `max_home_database_delay` to `0` because of the performance penalty. - - .. versionadded:: 5.x - - .. 
seealso:: - Driver config :ref:`max-home-database-delay-ref` - """ - home_db_cache = self._pool.home_db_cache - if home_db_cache.enabled: - home_db_cache.clear() - def _work( tx: ManagedTransaction, From 6465c592cce2b4f1a62ced003d962c937b2cec34 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Mon, 25 Nov 2024 10:08:24 +0100 Subject: [PATCH 12/26] Improve cache pruning performance --- src/neo4j/_async/home_db_cache.py | 7 ++-- src/neo4j/_sync/home_db_cache.py | 7 ++-- tests/unit/async_/test_home_db_cache.py | 55 +++++++++++++++++++++++-- tests/unit/sync/test_home_db_cache.py | 55 +++++++++++++++++++++++-- 4 files changed, 108 insertions(+), 16 deletions(-) diff --git a/src/neo4j/_async/home_db_cache.py b/src/neo4j/_async/home_db_cache.py index bcef66e4..0c0617ae 100644 --- a/src/neo4j/_async/home_db_cache.py +++ b/src/neo4j/_async/home_db_cache.py @@ -18,7 +18,6 @@ from __future__ import annotations -import heapq import math import typing as t from time import monotonic @@ -109,11 +108,11 @@ def _clean(self, now: float | None = None) -> None: ) if self._max_size and len(self._cache) > self._max_size: self._cache = dict( - heapq.nlargest( - self._max_size, + sorted( self._cache.items(), key=lambda item: item[1][0], - ) + reverse=True, + )[: int(self._max_size * 0.9)] ) def __len__(self) -> int: diff --git a/src/neo4j/_sync/home_db_cache.py b/src/neo4j/_sync/home_db_cache.py index 3e6167d4..89ba8d75 100644 --- a/src/neo4j/_sync/home_db_cache.py +++ b/src/neo4j/_sync/home_db_cache.py @@ -18,7 +18,6 @@ from __future__ import annotations -import heapq import math import typing as t from time import monotonic @@ -109,11 +108,11 @@ def _clean(self, now: float | None = None) -> None: ) if self._max_size and len(self._cache) > self._max_size: self._cache = dict( - heapq.nlargest( - self._max_size, + sorted( self._cache.items(), key=lambda item: item[1][0], - ) + reverse=True, + )[: int(self._max_size * 0.9)] ) def __len__(self) -> int: diff --git a/tests/unit/async_/test_home_db_cache.py b/tests/unit/async_/test_home_db_cache.py index 9b5f33f3..ad033c2f 100644 --- a/tests/unit/async_/test_home_db_cache.py +++ b/tests/unit/async_/test_home_db_cache.py @@ -16,6 +16,7 @@ from __future__ import annotations +import time import typing as t from datetime import ( datetime, @@ -26,7 +27,10 @@ import pytest import pytz +from neo4j._async.config import AsyncPoolConfig from neo4j._async.home_db_cache import AsyncHomeDbCache +from neo4j._async.io._pool import AsyncNeo4jPool +from neo4j._conf import WorkspaceConfig from neo4j.time import DateTime @@ -147,11 +151,16 @@ def test_key_auth_equality(auth1: dict, auth2: dict) -> None: def _assert_entries( cache: AsyncHomeDbCache, expected_entries: t.Collection[tuple[TKey, str]], + allow_subset: bool = False, ) -> None: __tracebackhide__ = True - assert len(cache) == len(expected_entries) - for key, value in expected_entries: - assert cache.get(key) == value + if not allow_subset: + assert len(cache) == len(expected_entries) + for key, value in expected_entries: + assert cache.get(key) == value + else: + hits = sum(cache.get(key) == value for key, value in expected_entries) + assert hits == len(cache) def _force_cache_clean( @@ -229,7 +238,7 @@ def test_cache_max_size() -> None: cache.set(key, value) _force_cache_clean(cache) - _assert_entries(cache, entries) + _assert_entries(cache, entries, allow_subset=True) def test_cache_max_size_empty_cache() -> None: @@ -237,3 +246,41 @@ def test_cache_max_size_empty_cache() -> None: assert len(cache) == 0 
_force_cache_clean(cache) assert len(cache) == 0 + + +def test_clean_up_time() -> None: + def get_default_cache(): + pool = AsyncNeo4jPool( + lambda: None, AsyncPoolConfig(), WorkspaceConfig(), None + ) + return pool.home_db_cache + + repetitions = 5 + scenario_timings = [] + + default_max_size = get_default_cache()._max_size + # Test assumes that by default the driver uses a home db cache only limited + # by its size. + assert default_max_size + for max_size, count in ( + # no pruning needed + (default_max_size * 10, default_max_size * 10), + # pruning needed + (default_max_size, default_max_size * 10), + ): + cache = AsyncHomeDbCache(max_size=max_size) + keys = [cache.compute_key(f"key{i}", None) for i in range(count)] + rep_timings = [] + for _ in range(repetitions): + t0 = time.perf_counter() + for key in keys: + cache.set(key, "value") + t1 = time.perf_counter() + rep_timings.append(t1 - t0) + scenario_timings.append(sum(rep_timings) / len(rep_timings)) + + # pruning shouldn't take more than 20 times the time of no pruning + # N.B., the pruning takes O(n * log(n)) where n is max_size. So to achieve + # this limit, either max_size needs to be sufficiently small or the pruning + # algorithm needs to be performant enough. + assert scenario_timings[1] <= 20 * scenario_timings[0] diff --git a/tests/unit/sync/test_home_db_cache.py b/tests/unit/sync/test_home_db_cache.py index 95b793f5..1c7aa977 100644 --- a/tests/unit/sync/test_home_db_cache.py +++ b/tests/unit/sync/test_home_db_cache.py @@ -16,6 +16,7 @@ from __future__ import annotations +import time import typing as t from datetime import ( datetime, @@ -26,7 +27,10 @@ import pytest import pytz +from neo4j._conf import WorkspaceConfig +from neo4j._sync.config import PoolConfig from neo4j._sync.home_db_cache import HomeDbCache +from neo4j._sync.io._pool import Neo4jPool from neo4j.time import DateTime @@ -147,11 +151,16 @@ def test_key_auth_equality(auth1: dict, auth2: dict) -> None: def _assert_entries( cache: HomeDbCache, expected_entries: t.Collection[tuple[TKey, str]], + allow_subset: bool = False, ) -> None: __tracebackhide__ = True - assert len(cache) == len(expected_entries) - for key, value in expected_entries: - assert cache.get(key) == value + if not allow_subset: + assert len(cache) == len(expected_entries) + for key, value in expected_entries: + assert cache.get(key) == value + else: + hits = sum(cache.get(key) == value for key, value in expected_entries) + assert hits == len(cache) def _force_cache_clean( @@ -229,7 +238,7 @@ def test_cache_max_size() -> None: cache.set(key, value) _force_cache_clean(cache) - _assert_entries(cache, entries) + _assert_entries(cache, entries, allow_subset=True) def test_cache_max_size_empty_cache() -> None: @@ -237,3 +246,41 @@ def test_cache_max_size_empty_cache() -> None: assert len(cache) == 0 _force_cache_clean(cache) assert len(cache) == 0 + + +def test_clean_up_time() -> None: + def get_default_cache(): + pool = Neo4jPool( + lambda: None, PoolConfig(), WorkspaceConfig(), None + ) + return pool.home_db_cache + + repetitions = 5 + scenario_timings = [] + + default_max_size = get_default_cache()._max_size + # Test assumes that by default the driver uses a home db cache only limited + # by its size. 
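+    # (Reading the size off a freshly constructed pool, rather than
+    # hard-coding it, keeps this benchmark in sync with the actual default.)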
+ assert default_max_size + for max_size, count in ( + # no pruning needed + (default_max_size * 10, default_max_size * 10), + # pruning needed + (default_max_size, default_max_size * 10), + ): + cache = HomeDbCache(max_size=max_size) + keys = [cache.compute_key(f"key{i}", None) for i in range(count)] + rep_timings = [] + for _ in range(repetitions): + t0 = time.perf_counter() + for key in keys: + cache.set(key, "value") + t1 = time.perf_counter() + rep_timings.append(t1 - t0) + scenario_timings.append(sum(rep_timings) / len(rep_timings)) + + # pruning shouldn't take more than 20 times the time of no pruning + # N.B., the pruning takes O(n * log(n)) where n is max_size. So to achieve + # this limit, either max_size needs to be sufficiently small or the pruning + # algorithm needs to be performant enough. + assert scenario_timings[1] <= 20 * scenario_timings[0] From c84d6fda443fb4036f92afac27f4efd7174b0861 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Wed, 11 Dec 2024 15:47:42 +0100 Subject: [PATCH 13/26] Fix Python backward compatibility of type aliases --- src/neo4j/_async/home_db_cache.py | 11 ++++++++--- src/neo4j/_sync/home_db_cache.py | 11 ++++++++--- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/src/neo4j/_async/home_db_cache.py b/src/neo4j/_async/home_db_cache.py index 0c0617ae..ee8d8a30 100644 --- a/src/neo4j/_async/home_db_cache.py +++ b/src/neo4j/_async/home_db_cache.py @@ -26,9 +26,14 @@ if t.TYPE_CHECKING: - # TAuthKey = t.Tuple[t.Tuple[]] - TKey = str | tuple[tuple[str, t.Hashable], ...] | tuple[None] - TVal = tuple[float, str] + import typing_extensions as te + + TKey: te.TypeAlias = t.Union[ + str, + t.Tuple[t.Tuple[str, t.Hashable], ...], + t.Tuple[None], + ] + TVal: te.TypeAlias = t.Tuple[float, str] class AsyncHomeDbCache: diff --git a/src/neo4j/_sync/home_db_cache.py b/src/neo4j/_sync/home_db_cache.py index 89ba8d75..de100d07 100644 --- a/src/neo4j/_sync/home_db_cache.py +++ b/src/neo4j/_sync/home_db_cache.py @@ -26,9 +26,14 @@ if t.TYPE_CHECKING: - # TAuthKey = t.Tuple[t.Tuple[]] - TKey = str | tuple[tuple[str, t.Hashable], ...] | tuple[None] - TVal = tuple[float, str] + import typing_extensions as te + + TKey: te.TypeAlias = t.Union[ + str, + t.Tuple[t.Tuple[str, t.Hashable], ...], + t.Tuple[None], + ] + TVal: te.TypeAlias = t.Tuple[float, str] class HomeDbCache: From 3186caa4b15c711ea93288b107c3ecd7b6072b4c Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Wed, 11 Dec 2024 15:48:34 +0100 Subject: [PATCH 14/26] Minor optimization --- src/neo4j/_async/work/workspace.py | 8 ++++++-- src/neo4j/_sync/work/workspace.py | 8 ++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/neo4j/_async/work/workspace.py b/src/neo4j/_async/work/workspace.py index e7859107..2d491548 100644 --- a/src/neo4j/_async/work/workspace.py +++ b/src/neo4j/_async/work/workspace.py @@ -196,8 +196,12 @@ async def _connect(self, access_mode, auth=None, **acquire_kwargs) -> None: } acquire_kwargs_.update(acquire_kwargs) self._connection = await self._pool.acquire(**acquire_kwargs_) - if target_db.guessed and ( - not self._pool.ssr_enabled or not self._connection.ssr_enabled + if ( + target_db.guessed + and not self._pinned_database + and ( + not self._pool.ssr_enabled or not self._connection.ssr_enabled + ) ): # race condition: in the meantime, the pool added a connection, # which does not support SSR. 
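
For readability, the new guard condenses to the following (a minimal sketch
with a hypothetical helper name; in the driver, the check lives inline in
``_connect`` as shown in the hunk above):

    def needs_explicit_home_db_resolution(
        target_db_guessed: bool,
        pinned_database: bool,
        pool_ssr: bool,
        connection_ssr: bool,
    ) -> bool:
        # Fall back to explicit home database resolution only when the
        # session ran on a guessed (cached) home database, has not pinned
        # a database yet, and SSR support is not guaranteed by both the
        # pool and the acquired connection.
        return (
            target_db_guessed
            and not pinned_database
            and not (pool_ssr and connection_ssr)
        )

The generated sync version below mirrors the same change.
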
diff --git a/src/neo4j/_sync/work/workspace.py b/src/neo4j/_sync/work/workspace.py
index 8ada7cc5..29eb293a 100644
--- a/src/neo4j/_sync/work/workspace.py
+++ b/src/neo4j/_sync/work/workspace.py
@@ -193,8 +193,12 @@ def _connect(self, access_mode, auth=None, **acquire_kwargs) -> None:
         }
         acquire_kwargs_.update(acquire_kwargs)
         self._connection = self._pool.acquire(**acquire_kwargs_)
-        if target_db.guessed and (
-            not self._pool.ssr_enabled or not self._connection.ssr_enabled
+        if (
+            target_db.guessed
+            and not self._pinned_database
+            and (
+                not self._pool.ssr_enabled or not self._connection.ssr_enabled
+            )
         ):
             # race condition: in the meantime, the pool added a connection,
             # which does not support SSR.

From ccaacf405f9e1c3028e2af7b4a35ec006de5c641 Mon Sep 17 00:00:00 2001
From: Robsdedude
Date: Thu, 12 Dec 2024 13:42:54 +0100
Subject: [PATCH 15/26] API docs: fix typo

---
 docs/source/api.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source/api.rst b/docs/source/api.rst
index 1220629d..7fb69b4b 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -1035,7 +1035,7 @@ Specifically, the following applies:
    all queries within that session are executed with the explicit database
    name 'movies' supplied. Any change to the user’s home database is
    reflected only in sessions created after such change takes effect. This
-   behavior may requires additional network communication. In clustered
+   behavior may require additional network communication. In clustered
    environments, it is strongly recommended to avoid a single point of
    failure. For instance, by ensuring that the connection URI resolves to
    multiple endpoints. For older Bolt protocol versions the behavior is the

From 6e41355c77a98bff583910b77febea1711f5b87e Mon Sep 17 00:00:00 2001
From: Robsdedude
Date: Thu, 12 Dec 2024 14:36:17 +0100
Subject: [PATCH 16/26] Increase test stability

* Increase timeouts for slow machines
* Avoid leaving non-cancelled coroutines behind when asyncio.gather
  encounters an exception

---
 tests/_async_util.py                           | 34 +++++++++++++++++++
 .../mixed/async_compat/test_concurrency.py     | 14 ++++----
 tests/unit/mixed/io/test_direct.py             |  9 ++---
 3 files changed, 47 insertions(+), 10 deletions(-)
 create mode 100644 tests/_async_util.py

diff --git a/tests/_async_util.py b/tests/_async_util.py
new file mode 100644
index 00000000..8a032ad8
--- /dev/null
+++ b/tests/_async_util.py
@@ -0,0 +1,34 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [https://neo4j.com]
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import asyncio
+
+
+async def gather_cancel(*coros_or_futures):
+    """
+    Await the given coroutines/futures, aggregating their results.
+
+    A thin wrapper around asyncio.gather that cancels all coroutines/futures
+    if any of them raises an exception.
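+
+    Unlike a bare asyncio.gather, this never leaves sibling
+    coroutines/futures running unawaited after a failure, e.g.::
+
+        async def slow():  # cancelled as soon as failing() raises
+            await asyncio.sleep(60)
+
+        async def failing():
+            raise RuntimeError("boom")
+
+        await gather_cancel(slow(), failing())  # raises RuntimeError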
+ """ + futures = [asyncio.ensure_future(coro) for coro in coros_or_futures] + try: + await asyncio.gather(*futures) + except: + for future in futures: + future.cancel() + await asyncio.gather(*futures, return_exceptions=True) + raise diff --git a/tests/unit/mixed/async_compat/test_concurrency.py b/tests/unit/mixed/async_compat/test_concurrency.py index 03356352..f4c17374 100644 --- a/tests/unit/mixed/async_compat/test_concurrency.py +++ b/tests/unit/mixed/async_compat/test_concurrency.py @@ -20,6 +20,8 @@ from neo4j._async_compat.concurrency import AsyncRLock +from ...._async_util import gather_cancel + @pytest.mark.asyncio async def test_async_r_lock(): @@ -36,7 +38,7 @@ async def worker(): assert counter == counter_ + 1 assert not lock.locked() - await asyncio.gather(worker(), worker(), worker()) + await gather_cancel(worker(), worker(), worker()) assert not lock.locked() @@ -52,7 +54,7 @@ async def worker(): assert lock.locked() assert not lock.locked() - await asyncio.gather(worker(), worker(), worker()) + await gather_cancel(worker(), worker(), worker()) assert not lock.locked() @@ -69,7 +71,7 @@ async def waiter(): assert not await lock.acquire(timeout=0.1) assert not lock.locked() - await asyncio.gather(blocker(), waiter()) + await gather_cancel(blocker(), waiter()) assert lock.locked() # blocker still owns it! @@ -90,7 +92,7 @@ async def waiter(): # blocker: lock.release() assert not lock.locked() - await asyncio.gather(blocker(), waiter()) + await gather_cancel(blocker(), waiter()) assert lock.locked() # waiter still owns it! @@ -162,7 +164,7 @@ async def waiter(): lock.release() assert not lock.locked() - await asyncio.gather(blocker(), waiter_non_blocking(), waiter()) + await gather_cancel(blocker(), waiter_non_blocking(), waiter()) assert lock.locked() # waiter_non_blocking still owns it! 
@@ -225,7 +227,7 @@ async def waiter_non_blocking(): awaits += 1 assert not lock.locked() - await asyncio.gather(blocker(), waiter_non_blocking()) + await gather_cancel(blocker(), waiter_non_blocking()) assert not lock.locked() diff --git a/tests/unit/mixed/io/test_direct.py b/tests/unit/mixed/io/test_direct.py index d943fccf..ed4a9428 100644 --- a/tests/unit/mixed/io/test_direct.py +++ b/tests/unit/mixed/io/test_direct.py @@ -33,6 +33,7 @@ from neo4j._deadline import Deadline from neo4j._sync.io._pool import AcquisitionAuth +from ...._async_util import gather_cancel from ...async_.io.test_direct import AsyncFakeBoltPool from ...async_.test_auth_management import ( static_auth_manager as static_async_auth_manager, @@ -193,7 +194,7 @@ async def acquire_release_conn( async def waiter(pool_, acquired_counter_, release_event_): nonlocal pre_populated_connections, connections - if not await acquired_counter_.wait(5, timeout=1): + if not await acquired_counter_.wait(5, timeout=5): raise RuntimeError("Acquire coroutines not fast enough") # The pool size should be 5, all are in-use self.assert_pool_size(address, 5, 0, pool_) @@ -205,7 +206,7 @@ async def waiter(pool_, acquired_counter_, release_event_): release_event_.set() # wait for all coroutines to release connections back to pool - if not await acquired_counter_.wait(10, timeout=5): + if not await acquired_counter_.wait(10, timeout=10): raise RuntimeError("Acquire coroutines not fast enough") # The pool size is still 5, but all are free self.assert_pool_size(address, 0, 5, pool_) @@ -234,7 +235,7 @@ async def waiter(pool_, acquired_counter_, release_event_): ) for _ in range(10) ] - await asyncio.gather( + await gather_cancel( waiter(pool, acquired_counter, release_event), *coroutines ) @@ -276,4 +277,4 @@ async def acquire2(pool_): async with AsyncFakeBoltPool( async_fake_connection_generator, (), max_connection_pool_size=1 ) as pool: - await asyncio.gather(acquire1(pool), acquire2(pool)) + await gather_cancel(acquire1(pool), acquire2(pool)) From 3a1ed8810d6a827915a69e4c7d2f80d4d0af886b Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Fri, 13 Dec 2024 12:31:43 +0100 Subject: [PATCH 17/26] Improve logging: log session pinning database --- src/neo4j/_async/work/workspace.py | 3 +++ src/neo4j/_sync/work/workspace.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/src/neo4j/_async/work/workspace.py b/src/neo4j/_async/work/workspace.py index 2d491548..4f33760b 100644 --- a/src/neo4j/_async/work/workspace.py +++ b/src/neo4j/_async/work/workspace.py @@ -119,6 +119,9 @@ def _database_callback(database: str | None) -> None: return _database_callback def _set_pinned_database(self, database): + if self._pinned_database: + return + log.debug("[#0000] _: pinning database: %r", database) self._pinned_database = True self._config.database = database diff --git a/src/neo4j/_sync/work/workspace.py b/src/neo4j/_sync/work/workspace.py index 29eb293a..49a6870d 100644 --- a/src/neo4j/_sync/work/workspace.py +++ b/src/neo4j/_sync/work/workspace.py @@ -116,6 +116,9 @@ def _database_callback(database: str | None) -> None: return _database_callback def _set_pinned_database(self, database): + if self._pinned_database: + return + log.debug("[#0000] _: pinning database: %r", database) self._pinned_database = True self._config.database = database From 53e7fe343c2bc8cf8478a1e508b30093c92b4948 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Fri, 13 Dec 2024 12:32:10 +0100 Subject: [PATCH 18/26] Restore pinning compatibility with legacy bolt versions --- 
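With this reordering, a session whose home database was pinned to None (as
happens against legacy Bolt servers, where the driver cannot resolve a home
database name) also takes the fixed-database fast path instead of retrying
the resolution on every acquisition. Condensed, the guard in
``_get_routing_target_database`` becomes (a sketch under that reading; the
actual inline condition is in the diff below):

    def uses_fixed_database(
        pinned_database: bool,
        configured_database: str | None,
        is_direct_pool: bool,
    ) -> bool:
        # Skip home database resolution once the session has pinned a
        # database, when a database was configured explicitly, or when
        # talking to a single server (direct pool).
        return (
            pinned_database
            or configured_database is not None
            or is_direct_pool
        )
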
src/neo4j/_async/work/workspace.py | 8 ++++++-- src/neo4j/_sync/work/workspace.py | 8 ++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/neo4j/_async/work/workspace.py b/src/neo4j/_async/work/workspace.py index 4f33760b..1fb2c136 100644 --- a/src/neo4j/_async/work/workspace.py +++ b/src/neo4j/_async/work/workspace.py @@ -226,12 +226,16 @@ async def _get_routing_target_database( acquire_auth: AcquisitionAuth, ssr_enabled: bool, ) -> AcquisitionDatabase: - if self._config.database is not None or self._pool.is_direct_pool: - self._set_pinned_database(self._config.database) + if ( + self._pinned_database + or self._config.database is not None + or self._pool.is_direct_pool + ): log.debug( "[#0000] _: routing towards fixed database: %s", self._config.database, ) + self._set_pinned_database(self._config.database) return AcquisitionDatabase(self._config.database) auth = acquire_auth.auth diff --git a/src/neo4j/_sync/work/workspace.py b/src/neo4j/_sync/work/workspace.py index 49a6870d..d9b0289d 100644 --- a/src/neo4j/_sync/work/workspace.py +++ b/src/neo4j/_sync/work/workspace.py @@ -223,12 +223,16 @@ def _get_routing_target_database( acquire_auth: AcquisitionAuth, ssr_enabled: bool, ) -> AcquisitionDatabase: - if self._config.database is not None or self._pool.is_direct_pool: - self._set_pinned_database(self._config.database) + if ( + self._pinned_database + or self._config.database is not None + or self._pool.is_direct_pool + ): log.debug( "[#0000] _: routing towards fixed database: %s", self._config.database, ) + self._set_pinned_database(self._config.database) return AcquisitionDatabase(self._config.database) auth = acquire_auth.auth From 5b29162e50517ff35a9d2ef1ba0e9c77cea1ef17 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Fri, 13 Dec 2024 12:32:38 +0100 Subject: [PATCH 19/26] Simplify code --- src/neo4j/_async/io/_pool.py | 7 +------ src/neo4j/_sync/io/_pool.py | 7 +------ 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index 45ec6371..a6fd2226 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -1014,12 +1014,7 @@ async def update_routing_table( async def update_connection_pool(self, *, database): async with self.refresh_lock: - rt = await self.get_routing_table(database) - routing_tables = [rt] if rt is not None else [] - for db in self.routing_tables: - if db == database: - continue - routing_tables.append(self.routing_tables[db]) + routing_tables = list(self.routing_tables.values()) servers = set.union( *(rt.servers() for rt in routing_tables), diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py index c8995725..916f9197 100644 --- a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -1011,12 +1011,7 @@ def update_routing_table( def update_connection_pool(self, *, database): with self.refresh_lock: - rt = self.get_routing_table(database) - routing_tables = [rt] if rt is not None else [] - for db in self.routing_tables: - if db == database: - continue - routing_tables.append(self.routing_tables[db]) + routing_tables = list(self.routing_tables.values()) servers = set.union( *(rt.servers() for rt in routing_tables), From 232f0e609352eb8192dfaab36316ca6eddd84f2f Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Fri, 13 Dec 2024 14:25:02 +0100 Subject: [PATCH 20/26] Fix internal TestKit helper request --- testkitbackend/_async/requests.py | 5 ++++- testkitbackend/_sync/requests.py | 5 ++++- 2 files changed, 8 insertions(+), 2 
deletions(-) diff --git a/testkitbackend/_async/requests.py b/testkitbackend/_async/requests.py index 3e26e389..08f00917 100644 --- a/testkitbackend/_async/requests.py +++ b/testkitbackend/_async/requests.py @@ -30,6 +30,7 @@ import neo4j.api import neo4j.auth_management from neo4j._async_compat.util import AsyncUtil +from neo4j._routing import RoutingTable from neo4j.auth_management import ( AsyncAuthManager, AsyncAuthManagers, @@ -991,7 +992,9 @@ async def get_routing_table(backend, data): driver_id = data["driverId"] database = data["database"] driver = backend.drivers[driver_id] - routing_table = driver._pool.routing_tables[database] + routing_table = await driver._pool.get_routing_table(database) + if routing_table is None: + routing_table = RoutingTable(database=database) response_data = { "database": routing_table.database, "ttl": routing_table.ttl, diff --git a/testkitbackend/_sync/requests.py b/testkitbackend/_sync/requests.py index 44a9233b..586616e7 100644 --- a/testkitbackend/_sync/requests.py +++ b/testkitbackend/_sync/requests.py @@ -30,6 +30,7 @@ import neo4j.api import neo4j.auth_management from neo4j._async_compat.util import Util +from neo4j._routing import RoutingTable from neo4j.auth_management import ( AuthManager, AuthManagers, @@ -991,7 +992,9 @@ def get_routing_table(backend, data): driver_id = data["driverId"] database = data["database"] driver = backend.drivers[driver_id] - routing_table = driver._pool.routing_tables[database] + routing_table = driver._pool.get_routing_table(database) + if routing_table is None: + routing_table = RoutingTable(database=database) response_data = { "database": routing_table.database, "ttl": routing_table.ttl, From debf8a4e96de6bdad4dbbc516ce46305d167ec73 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Tue, 17 Dec 2024 12:17:36 +0100 Subject: [PATCH 21/26] Clean-up and polish --- .gitattributes | 4 ++-- docs/source/api.rst | 2 +- src/neo4j/_async/home_db_cache.py | 11 ++++++++-- src/neo4j/_async/io/_pool.py | 19 +++++++++------- src/neo4j/_async/work/workspace.py | 11 +++++----- src/neo4j/_sync/home_db_cache.py | 11 ++++++++-- src/neo4j/_sync/io/_pool.py | 19 +++++++++------- src/neo4j/_sync/work/workspace.py | 11 +++++----- tests/_async_compat/__init__.py | 13 +++++++++++ tests/unit/async_/io/test_class_bolt3.py | 22 +++++++++++++++++++ tests/unit/async_/io/test_class_bolt4x0.py | 22 +++++++++++++++++++ tests/unit/async_/io/test_class_bolt4x1.py | 22 +++++++++++++++++++ tests/unit/async_/io/test_class_bolt4x2.py | 22 +++++++++++++++++++ tests/unit/async_/io/test_class_bolt4x3.py | 22 +++++++++++++++++++ tests/unit/async_/io/test_class_bolt4x4.py | 22 +++++++++++++++++++ tests/unit/async_/io/test_class_bolt5x0.py | 22 +++++++++++++++++++ tests/unit/async_/io/test_class_bolt5x1.py | 22 +++++++++++++++++++ tests/unit/async_/io/test_class_bolt5x2.py | 22 +++++++++++++++++++ tests/unit/async_/io/test_class_bolt5x3.py | 22 +++++++++++++++++++ tests/unit/async_/io/test_class_bolt5x4.py | 22 +++++++++++++++++++ tests/unit/async_/io/test_class_bolt5x5.py | 22 +++++++++++++++++++ tests/unit/async_/io/test_class_bolt5x6.py | 22 +++++++++++++++++++ tests/unit/async_/io/test_class_bolt5x7.py | 22 +++++++++++++++++++ tests/unit/async_/io/test_class_bolt5x8.py | 22 +++++++++++++++++++ tests/unit/async_/test_home_db_cache.py | 17 ++++++++++----- tests/unit/async_/work/test_result.py | 25 +++++++++++----------- tests/unit/async_/work/test_session.py | 13 ++++++----- tests/unit/sync/io/test_class_bolt3.py | 22 +++++++++++++++++++ 
tests/unit/sync/io/test_class_bolt4x0.py | 22 +++++++++++++++++++ tests/unit/sync/io/test_class_bolt4x1.py | 22 +++++++++++++++++++ tests/unit/sync/io/test_class_bolt4x2.py | 22 +++++++++++++++++++ tests/unit/sync/io/test_class_bolt4x3.py | 22 +++++++++++++++++++ tests/unit/sync/io/test_class_bolt4x4.py | 22 +++++++++++++++++++ tests/unit/sync/io/test_class_bolt5x0.py | 22 +++++++++++++++++++ tests/unit/sync/io/test_class_bolt5x1.py | 22 +++++++++++++++++++ tests/unit/sync/io/test_class_bolt5x2.py | 22 +++++++++++++++++++ tests/unit/sync/io/test_class_bolt5x3.py | 22 +++++++++++++++++++ tests/unit/sync/io/test_class_bolt5x4.py | 22 +++++++++++++++++++ tests/unit/sync/io/test_class_bolt5x5.py | 22 +++++++++++++++++++ tests/unit/sync/io/test_class_bolt5x6.py | 22 +++++++++++++++++++ tests/unit/sync/io/test_class_bolt5x7.py | 22 +++++++++++++++++++ tests/unit/sync/io/test_class_bolt5x8.py | 22 +++++++++++++++++++ tests/unit/sync/test_home_db_cache.py | 17 ++++++++++----- tests/unit/sync/work/test_result.py | 25 +++++++++++----------- tests/unit/sync/work/test_session.py | 13 ++++++----- 45 files changed, 794 insertions(+), 77 deletions(-) diff --git a/.gitattributes b/.gitattributes index ca24c0e4..35dc3fdc 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,5 +1,5 @@ # configure github not to display generated files /src/neo4j/_sync/** linguist-generated=true -/tests/unit/sync_/** linguist-generated=true -/tests/integration/sync_/** linguist-generated=true +/tests/unit/sync/** linguist-generated=true +/tests/integration/sync/** linguist-generated=true /testkitbackend/_sync/** linguist-generated=true diff --git a/docs/source/api.rst b/docs/source/api.rst index 7fb69b4b..281d8c92 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -165,7 +165,7 @@ Closing a driver will immediately shut down all connections in the pool. .. autoclass:: neo4j.Driver() :members: session, execute_query_bookmark_manager, encrypted, close, verify_connectivity, get_server_info, verify_authentication, - supports_session_auth, supports_multi_db, force_home_database_resolution + supports_session_auth, supports_multi_db .. 
method:: execute_query(query, parameters_=None,routing_=neo4j.RoutingControl.WRITE, database_=None, impersonated_user_=None, bookmark_manager_=self.execute_query_bookmark_manager, auth_=None, result_transformer_=Result.to_eager_result, **kwargs) diff --git a/src/neo4j/_async/home_db_cache.py b/src/neo4j/_async/home_db_cache.py index ee8d8a30..96c2850f 100644 --- a/src/neo4j/_async/home_db_cache.py +++ b/src/neo4j/_async/home_db_cache.py @@ -60,6 +60,11 @@ def __init__( f"got {max_size}" ) self._max_size = max_size + self._truncate_size = ( + min(max_size, int(0.01 * max_size * math.log(max_size))) + if max_size is not None + else None + ) def compute_key( self, @@ -106,7 +111,9 @@ def _clean(self, now: float | None = None) -> None: now = monotonic() if now is None else now if now - self._oldest_entry > self._ttl: self._cache = { - k: v for k, v in self._cache.items() if now - v[0] < self._ttl + k: v + for k, v in self._cache.items() + if now - v[0] < self._ttl * 0.9 } self._oldest_entry = min( (v[0] for v in self._cache.values()), default=now @@ -117,7 +124,7 @@ def _clean(self, now: float | None = None) -> None: self._cache.items(), key=lambda item: item[1][0], reverse=True, - )[: int(self._max_size * 0.9)] + )[: self._truncate_size] ) def __len__(self) -> int: diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index a6fd2226..be697ddb 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -105,13 +105,13 @@ def add_connection(self, connection): def remove_connection(self, connection): if self.feature_check(connection): if self.with_feature == 0: - raise ValueError( + raise RuntimeError( "No connections to be removed from feature tracker" ) self.with_feature -= 1 else: if self.without_feature == 0: - raise ValueError( + raise RuntimeError( "No connections to be removed from feature tracker" ) self.without_feature -= 1 @@ -143,7 +143,8 @@ def is_direct_pool(self) -> bool: ... 
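+    # Connections are added to and removed from the tracker concurrently by
+    # the pool, so the aggregated flag must only be read under the lock.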
@property def ssr_enabled(self) -> bool: - return self._ssr_feature_tracker.has_feature + with self.lock: + return self._ssr_feature_tracker.has_feature async def __aenter__(self): return self @@ -601,8 +602,8 @@ async def close(self): for address in list(self.connections) for connection in self.connections.pop(address, ()) ] - for connection in connections: - self._ssr_feature_tracker.remove_connection(connection) + for connection in connections: + self._ssr_feature_tracker.remove_connection(connection) await self._close_connections(connections) except TypeError: pass @@ -1012,7 +1013,7 @@ async def update_routing_table( log.error("Unable to retrieve routing information") raise ServiceUnavailable("Unable to retrieve routing information") - async def update_connection_pool(self, *, database): + async def update_connection_pool(self): async with self.refresh_lock: routing_tables = list(self.routing_tables.values()) @@ -1077,12 +1078,14 @@ async def ensure_routing_table_is_fresh( ) return False + database_request = database.name if not database.guessed else None + async def wrapped_database_callback(database: str | None) -> None: await AsyncUtil.callback(database_callback, database) - await self.update_connection_pool(database=database) + await self.update_connection_pool() await self.update_routing_table( - database=database.name if not database.guessed else None, + database=database_request, imp_user=imp_user, bookmarks=bookmarks, auth=auth, diff --git a/src/neo4j/_async/work/workspace.py b/src/neo4j/_async/work/workspace.py index 1fb2c136..e88c4a7f 100644 --- a/src/neo4j/_async/work/workspace.py +++ b/src/neo4j/_async/work/workspace.py @@ -104,14 +104,15 @@ async def __aenter__(self) -> AsyncWorkspace: async def __aexit__(self, exc_type, exc_value, traceback): await self.close() - def _make_db_resolution_callback(self) -> t.Callable[[str], None] | None: + def _make_db_resolution_callback( + self, + ) -> t.Callable[[str | None], None] | None: if self._pinned_database: return None def _database_callback(database: str | None) -> None: - if not self._pinned_database: - self._set_pinned_database(database) - if self._last_cache_key is None: + self._set_pinned_database(database) + if self._last_cache_key is None or database is None: return db_cache: AsyncHomeDbCache = self._pool.home_db_cache db_cache.set(self._last_cache_key, database) @@ -206,7 +207,7 @@ async def _connect(self, access_mode, auth=None, **acquire_kwargs) -> None: not self._pool.ssr_enabled or not self._connection.ssr_enabled ) ): - # race condition: in the meantime, the pool added a connection, + # race condition: in the meantime the pool added a connection # which does not support SSR. 
# => we need to fall back to explicit home database resolution log.debug( diff --git a/src/neo4j/_sync/home_db_cache.py b/src/neo4j/_sync/home_db_cache.py index de100d07..904fa662 100644 --- a/src/neo4j/_sync/home_db_cache.py +++ b/src/neo4j/_sync/home_db_cache.py @@ -60,6 +60,11 @@ def __init__( f"got {max_size}" ) self._max_size = max_size + self._truncate_size = ( + min(max_size, int(0.01 * max_size * math.log(max_size))) + if max_size is not None + else None + ) def compute_key( self, @@ -106,7 +111,9 @@ def _clean(self, now: float | None = None) -> None: now = monotonic() if now is None else now if now - self._oldest_entry > self._ttl: self._cache = { - k: v for k, v in self._cache.items() if now - v[0] < self._ttl + k: v + for k, v in self._cache.items() + if now - v[0] < self._ttl * 0.9 } self._oldest_entry = min( (v[0] for v in self._cache.values()), default=now @@ -117,7 +124,7 @@ def _clean(self, now: float | None = None) -> None: self._cache.items(), key=lambda item: item[1][0], reverse=True, - )[: int(self._max_size * 0.9)] + )[: self._truncate_size] ) def __len__(self) -> int: diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py index 916f9197..c04c1165 100644 --- a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -102,13 +102,13 @@ def add_connection(self, connection): def remove_connection(self, connection): if self.feature_check(connection): if self.with_feature == 0: - raise ValueError( + raise RuntimeError( "No connections to be removed from feature tracker" ) self.with_feature -= 1 else: if self.without_feature == 0: - raise ValueError( + raise RuntimeError( "No connections to be removed from feature tracker" ) self.without_feature -= 1 @@ -140,7 +140,8 @@ def is_direct_pool(self) -> bool: ... @property def ssr_enabled(self) -> bool: - return self._ssr_feature_tracker.has_feature + with self.lock: + return self._ssr_feature_tracker.has_feature def __enter__(self): return self @@ -598,8 +599,8 @@ def close(self): for address in list(self.connections) for connection in self.connections.pop(address, ()) ] - for connection in connections: - self._ssr_feature_tracker.remove_connection(connection) + for connection in connections: + self._ssr_feature_tracker.remove_connection(connection) self._close_connections(connections) except TypeError: pass @@ -1009,7 +1010,7 @@ def update_routing_table( log.error("Unable to retrieve routing information") raise ServiceUnavailable("Unable to retrieve routing information") - def update_connection_pool(self, *, database): + def update_connection_pool(self): with self.refresh_lock: routing_tables = list(self.routing_tables.values()) @@ -1074,12 +1075,14 @@ def ensure_routing_table_is_fresh( ) return False + database_request = database.name if not database.guessed else None + def wrapped_database_callback(database: str | None) -> None: Util.callback(database_callback, database) - self.update_connection_pool(database=database) + self.update_connection_pool() self.update_routing_table( - database=database.name if not database.guessed else None, + database=database_request, imp_user=imp_user, bookmarks=bookmarks, auth=auth, diff --git a/src/neo4j/_sync/work/workspace.py b/src/neo4j/_sync/work/workspace.py index d9b0289d..39955001 100644 --- a/src/neo4j/_sync/work/workspace.py +++ b/src/neo4j/_sync/work/workspace.py @@ -101,14 +101,15 @@ def __enter__(self) -> Workspace: def __exit__(self, exc_type, exc_value, traceback): self.close() - def _make_db_resolution_callback(self) -> t.Callable[[str], None] | None: + 
def _make_db_resolution_callback(
+        self,
+    ) -> t.Callable[[str | None], None] | None:
         if self._pinned_database:
             return None
 
         def _database_callback(database: str | None) -> None:
-            if not self._pinned_database:
-                self._set_pinned_database(database)
-            if self._last_cache_key is None:
+            self._set_pinned_database(database)
+            if self._last_cache_key is None or database is None:
                 return
             db_cache: HomeDbCache = self._pool.home_db_cache
             db_cache.set(self._last_cache_key, database)
@@ -203,7 +204,7 @@ def _connect(self, access_mode, auth=None, **acquire_kwargs) -> None:
                 not self._pool.ssr_enabled or not self._connection.ssr_enabled
             )
         ):
-            # race condition: in the meantime, the pool added a connection,
+            # race condition: in the meantime the pool added a connection
             # which does not support SSR.
             # => we need to fall back to explicit home database resolution
             log.debug(
diff --git a/tests/_async_compat/__init__.py b/tests/_async_compat/__init__.py
index 67cd7a37..8170965a 100644
--- a/tests/_async_compat/__init__.py
+++ b/tests/_async_compat/__init__.py
@@ -14,6 +14,8 @@
 # limitations under the License.
 
 
+from functools import wraps as _wraps
+
 from .mark_decorator import (
     AsyncTestDecorators,
     mark_async_test,
@@ -27,4 +29,15 @@
     "TestDecorators",
     "mark_async_test",
     "mark_sync_test",
+    "wrap_async",
 ]
+
+
+def wrap_async(func):
+    @_wraps(func)
+    async def wrapper(*args, **kwargs):  # noqa: RUF029
+        # [noqa] the whole point of this wrapper is to turn a sync function
+        # into an async one for testing purposes
+        return func(*args, **kwargs)
+
+    return wrapper
diff --git a/tests/unit/async_/io/test_class_bolt3.py b/tests/unit/async_/io/test_class_bolt3.py
index 6442e8f9..509b6c76 100644
--- a/tests/unit/async_/io/test_class_bolt3.py
+++ b/tests/unit/async_/io/test_class_bolt3.py
@@ -569,3 +569,25 @@ def on_success(metadata):
     await connection.fetch_all()
 
     assert received_metadata == sent_metadata
+
+
+@pytest.mark.parametrize("ssr_hint", (True, False, None))
+@mark_async_test
+async def test_ssr_enabled(ssr_hint, fake_socket_pair):
+    address = neo4j.Address(("127.0.0.1", 7687))
+    sockets = fake_socket_pair(
+        address,
+        packer_cls=AsyncBolt3.PACKER_CLS,
+        unpacker_cls=AsyncBolt3.UNPACKER_CLS,
+    )
+    meta = {"server": "Neo4j/4.3.4"}
+    if ssr_hint is not None:
+        meta["hints"] = {"ssr.enabled": ssr_hint}
+    await sockets.server.send_message(b"\x70", meta)
+    await sockets.server.send_message(b"\x70", {})
+    connection = AsyncBolt3(
+        address, sockets.client, AsyncPoolConfig.max_connection_lifetime
+    )
+    assert connection.ssr_enabled is False
+    await connection.hello()
+    assert connection.ssr_enabled is False
diff --git a/tests/unit/async_/io/test_class_bolt4x0.py b/tests/unit/async_/io/test_class_bolt4x0.py
index 771f7e30..e3af55ad 100644
--- a/tests/unit/async_/io/test_class_bolt4x0.py
+++ b/tests/unit/async_/io/test_class_bolt4x0.py
@@ -660,3 +660,25 @@ def on_success(metadata):
     await connection.fetch_all()
 
     assert received_metadata == sent_metadata
+
+
+@pytest.mark.parametrize("ssr_hint", (True, False, None))
+@mark_async_test
+async def test_ssr_enabled(ssr_hint, fake_socket_pair):
+    address = neo4j.Address(("127.0.0.1", 7687))
+    sockets = fake_socket_pair(
+        address,
+        packer_cls=AsyncBolt4x0.PACKER_CLS,
+        unpacker_cls=AsyncBolt4x0.UNPACKER_CLS,
+    )
+    meta = {"server": "Neo4j/4.3.4"}
+    if ssr_hint is not None:
+        meta["hints"] = {"ssr.enabled": ssr_hint}
+    await sockets.server.send_message(b"\x70", meta)
+    await sockets.server.send_message(b"\x70", {})
+    connection = AsyncBolt4x0(
+        address, sockets.client,
AsyncPoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + await connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/async_/io/test_class_bolt4x1.py b/tests/unit/async_/io/test_class_bolt4x1.py index cd37fce4..7f33fe01 100644 --- a/tests/unit/async_/io/test_class_bolt4x1.py +++ b/tests/unit/async_/io/test_class_bolt4x1.py @@ -682,3 +682,25 @@ def on_success(metadata): await connection.fetch_all() assert received_metadata == sent_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_async_test +async def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt4x1.PACKER_CLS, + unpacker_cls=AsyncBolt4x1.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + await sockets.server.send_message(b"\x70", meta) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt4x1( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + await connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/async_/io/test_class_bolt4x2.py b/tests/unit/async_/io/test_class_bolt4x2.py index 54180fa0..d243aef6 100644 --- a/tests/unit/async_/io/test_class_bolt4x2.py +++ b/tests/unit/async_/io/test_class_bolt4x2.py @@ -682,3 +682,25 @@ def on_success(metadata): await connection.fetch_all() assert received_metadata == sent_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_async_test +async def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt4x2.PACKER_CLS, + unpacker_cls=AsyncBolt4x2.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + await sockets.server.send_message(b"\x70", meta) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt4x2( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + await connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/async_/io/test_class_bolt4x3.py b/tests/unit/async_/io/test_class_bolt4x3.py index ba39bfbb..e1a23e01 100644 --- a/tests/unit/async_/io/test_class_bolt4x3.py +++ b/tests/unit/async_/io/test_class_bolt4x3.py @@ -711,3 +711,25 @@ def on_success(metadata): await connection.fetch_all() assert received_metadata == sent_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_async_test +async def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt4x3.PACKER_CLS, + unpacker_cls=AsyncBolt4x3.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + await sockets.server.send_message(b"\x70", meta) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt4x3( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + await connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/async_/io/test_class_bolt4x4.py b/tests/unit/async_/io/test_class_bolt4x4.py index b66f8b32..c24f06a9 100644 --- 
a/tests/unit/async_/io/test_class_bolt4x4.py +++ b/tests/unit/async_/io/test_class_bolt4x4.py @@ -671,3 +671,25 @@ def on_success(metadata): await connection.fetch_all() assert received_metadata == sent_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_async_test +async def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt4x4.PACKER_CLS, + unpacker_cls=AsyncBolt4x4.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + await sockets.server.send_message(b"\x70", meta) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt4x4( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + await connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/async_/io/test_class_bolt5x0.py b/tests/unit/async_/io/test_class_bolt5x0.py index 823d9ddb..8a5de715 100644 --- a/tests/unit/async_/io/test_class_bolt5x0.py +++ b/tests/unit/async_/io/test_class_bolt5x0.py @@ -735,3 +735,25 @@ def on_success(metadata): await connection.fetch_all() assert received_metadata == sent_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_async_test +async def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x0.PACKER_CLS, + unpacker_cls=AsyncBolt5x0.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + await sockets.server.send_message(b"\x70", meta) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x0( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + await connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/async_/io/test_class_bolt5x1.py b/tests/unit/async_/io/test_class_bolt5x1.py index 0eef369e..847ff059 100644 --- a/tests/unit/async_/io/test_class_bolt5x1.py +++ b/tests/unit/async_/io/test_class_bolt5x1.py @@ -789,3 +789,25 @@ def on_success(metadata): await connection.fetch_all() assert received_metadata == sent_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_async_test +async def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x1.PACKER_CLS, + unpacker_cls=AsyncBolt5x1.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + await sockets.server.send_message(b"\x70", meta) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x1( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + await connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/async_/io/test_class_bolt5x2.py b/tests/unit/async_/io/test_class_bolt5x2.py index 2e25ec1d..c4c08eda 100644 --- a/tests/unit/async_/io/test_class_bolt5x2.py +++ b/tests/unit/async_/io/test_class_bolt5x2.py @@ -826,3 +826,25 @@ def on_success(metadata): await connection.fetch_all() assert received_metadata == sent_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_async_test +async def test_ssr_enabled(ssr_hint, 
fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x2.PACKER_CLS, + unpacker_cls=AsyncBolt5x2.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + await sockets.server.send_message(b"\x70", meta) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x2( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + await connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/async_/io/test_class_bolt5x3.py b/tests/unit/async_/io/test_class_bolt5x3.py index 5d451025..e3a76563 100644 --- a/tests/unit/async_/io/test_class_bolt5x3.py +++ b/tests/unit/async_/io/test_class_bolt5x3.py @@ -713,3 +713,25 @@ def on_success(metadata): await connection.fetch_all() assert received_metadata == sent_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_async_test +async def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x3.PACKER_CLS, + unpacker_cls=AsyncBolt5x3.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + await sockets.server.send_message(b"\x70", meta) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x3( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + await connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/async_/io/test_class_bolt5x4.py b/tests/unit/async_/io/test_class_bolt5x4.py index 99d9c686..48e74114 100644 --- a/tests/unit/async_/io/test_class_bolt5x4.py +++ b/tests/unit/async_/io/test_class_bolt5x4.py @@ -718,3 +718,25 @@ def on_success(metadata): await connection.fetch_all() assert received_metadata == sent_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_async_test +async def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x4.PACKER_CLS, + unpacker_cls=AsyncBolt5x4.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + await sockets.server.send_message(b"\x70", meta) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x4( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + await connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/async_/io/test_class_bolt5x5.py b/tests/unit/async_/io/test_class_bolt5x5.py index d1ea0e51..60ea25ee 100644 --- a/tests/unit/async_/io/test_class_bolt5x5.py +++ b/tests/unit/async_/io/test_class_bolt5x5.py @@ -756,3 +756,25 @@ def extend_diag_record(r): } assert received_metadata == expected_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_async_test +async def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x5.PACKER_CLS, + unpacker_cls=AsyncBolt5x5.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + await sockets.server.send_message(b"\x70", meta) + 
await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x5( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + await connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/async_/io/test_class_bolt5x6.py b/tests/unit/async_/io/test_class_bolt5x6.py index 533af97e..1f11cf75 100644 --- a/tests/unit/async_/io/test_class_bolt5x6.py +++ b/tests/unit/async_/io/test_class_bolt5x6.py @@ -760,3 +760,25 @@ def extend_diag_record(r): } assert received_metadata == expected_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_async_test +async def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x6.PACKER_CLS, + unpacker_cls=AsyncBolt5x6.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + await sockets.server.send_message(b"\x70", meta) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x6( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + await connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/async_/io/test_class_bolt5x7.py b/tests/unit/async_/io/test_class_bolt5x7.py index 3758157a..09752de1 100644 --- a/tests/unit/async_/io/test_class_bolt5x7.py +++ b/tests/unit/async_/io/test_class_bolt5x7.py @@ -848,3 +848,25 @@ def _build_error_hierarchy_metadata(diag_records_metadata): if r is not ...: current_root["diagnostic_record"] = r return metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_async_test +async def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x7.PACKER_CLS, + unpacker_cls=AsyncBolt5x7.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + await sockets.server.send_message(b"\x70", meta) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x7( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + await connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/async_/io/test_class_bolt5x8.py b/tests/unit/async_/io/test_class_bolt5x8.py index e3d572d1..6a105d1e 100644 --- a/tests/unit/async_/io/test_class_bolt5x8.py +++ b/tests/unit/async_/io/test_class_bolt5x8.py @@ -848,3 +848,25 @@ def _build_error_hierarchy_metadata(diag_records_metadata): if r is not ...: current_root["diagnostic_record"] = r return metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_async_test +async def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + await sockets.server.send_message(b"\x70", meta) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x8( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + await connection.hello() + assert connection.ssr_enabled is bool(ssr_hint) diff --git 
diff --git a/tests/unit/async_/test_home_db_cache.py b/tests/unit/async_/test_home_db_cache.py
index ad033c2f..fe644ed1 100644
--- a/tests/unit/async_/test_home_db_cache.py
+++ b/tests/unit/async_/test_home_db_cache.py
@@ -16,6 +16,7 @@
 
 from __future__ import annotations
 
+import math
 import time
 import typing as t
 from datetime import (
@@ -258,10 +259,15 @@ def get_default_cache():
     repetitions = 5
     scenario_timings = []
 
-    default_max_size = get_default_cache()._max_size
     # Test assumes that by default the driver uses a home db cache only limited
     # by its size.
-    assert default_max_size
+    default_cache = get_default_cache()
+    default_max_size = default_cache._max_size
+    assert isinstance(default_max_size, int)
+    # If ttl ever gets used, this test needs to be updated to also test pruning
+    # by TTL.
+    assert math.isinf(default_cache._ttl) and default_cache._ttl > 0
+
     for max_size, count in (
         # no pruning needed
        (default_max_size * 10, default_max_size * 10),
@@ -280,7 +286,8 @@ def get_default_cache():
         scenario_timings.append(sum(rep_timings) / len(rep_timings))
 
     # pruning shouldn't take more than 20 times the time of no pruning
-    # N.B., the pruning takes O(n * log(n)) where n is max_size. So to achieve
-    # this limit, either max_size needs to be sufficiently small or the pruning
-    # algorithm needs to be performant enough.
+    # N.B., the pruning takes O(n * log(n)) where n is max_size. By only
+    # pruning O(n * log(n)) elements, we get an amortized pruning overhead of
+    # O(1) (as long as max_size is small enough to be able to choose a positive
+    # pruning size).
     assert scenario_timings[1] <= 20 * scenario_timings[0]
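The rewritten comment above argues that pruning amortizes to O(1) per insertion: one prune costs O(n * log(n)), dominated by sorting n = max_size entries, so if each prune also frees room for on the order of n * log(n) subsequent insertions, the sort is paid off by the time the next prune runs. A toy model of that accounting, under the assumption that the cache may overshoot max_size by one pruning batch; all names here are illustrative and this is not the driver's home db cache:

    import math
    import time


    class BatchPrunedCache:
        def __init__(self, max_size):
            self._max_size = max_size
            # Batch sized so one O(n * log(n)) sort pays for that many inserts.
            self._prune_batch = max(1, int(max_size * math.log2(max_size)))
            self._entries = {}

        def set(self, key, value):
            self._entries[key] = (time.monotonic(), value)
            # Allow overshooting max_size by one batch before pruning ...
            if len(self._entries) > self._max_size + self._prune_batch:
                # ... then sort once, O(n * log(n)), and drop back to max_size.
                oldest = sorted(self._entries, key=lambda k: self._entries[k][0])
                for stale_key in oldest[: len(self._entries) - self._max_size]:
                    del self._entries[stale_key]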
"bookmarks"])[:1] + "additional_session_bookmarks", (None, ["session", "bookmarks"]) ) @mark_async_test async def test_with_bookmark_manager( @@ -720,7 +720,10 @@ async def work(_): @pytest.mark.parametrize("imp_user", (None, "imp_user")) @pytest.mark.parametrize( "auth", - (None, Auth(scheme="magic-auth", principal=None, credentials="tada")), + ( + None, + Auth(scheme="magic-auth", principal=None, credentials="tada"), + ), ) @mark_async_test async def test_uses_home_db_cache_when_expected( diff --git a/tests/unit/sync/io/test_class_bolt3.py b/tests/unit/sync/io/test_class_bolt3.py index 980694bb..af3d3c6b 100644 --- a/tests/unit/sync/io/test_class_bolt3.py +++ b/tests/unit/sync/io/test_class_bolt3.py @@ -569,3 +569,25 @@ def on_success(metadata): connection.fetch_all() assert received_metadata == sent_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_sync_test +def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt3.PACKER_CLS, + unpacker_cls=Bolt3.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + sockets.server.send_message(b"\x70", meta) + sockets.server.send_message(b"\x70", {}) + connection = Bolt3( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/sync/io/test_class_bolt4x0.py b/tests/unit/sync/io/test_class_bolt4x0.py index be3f4499..f9dfef4b 100644 --- a/tests/unit/sync/io/test_class_bolt4x0.py +++ b/tests/unit/sync/io/test_class_bolt4x0.py @@ -660,3 +660,25 @@ def on_success(metadata): connection.fetch_all() assert received_metadata == sent_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_sync_test +def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt4x0.PACKER_CLS, + unpacker_cls=Bolt4x0.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + sockets.server.send_message(b"\x70", meta) + sockets.server.send_message(b"\x70", {}) + connection = Bolt4x0( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/sync/io/test_class_bolt4x1.py b/tests/unit/sync/io/test_class_bolt4x1.py index 77b54513..219a9fda 100644 --- a/tests/unit/sync/io/test_class_bolt4x1.py +++ b/tests/unit/sync/io/test_class_bolt4x1.py @@ -682,3 +682,25 @@ def on_success(metadata): connection.fetch_all() assert received_metadata == sent_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_sync_test +def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt4x1.PACKER_CLS, + unpacker_cls=Bolt4x1.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + sockets.server.send_message(b"\x70", meta) + sockets.server.send_message(b"\x70", {}) + connection = Bolt4x1( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + connection.hello() + assert connection.ssr_enabled is False diff --git 
diff --git a/tests/unit/sync/io/test_class_bolt3.py b/tests/unit/sync/io/test_class_bolt3.py
index 980694bb..af3d3c6b 100644
--- a/tests/unit/sync/io/test_class_bolt3.py
+++ b/tests/unit/sync/io/test_class_bolt3.py
@@ -569,3 +569,25 @@ def on_success(metadata):
     connection.fetch_all()
 
     assert received_metadata == sent_metadata
+
+
+@pytest.mark.parametrize("ssr_hint", (True, False, None))
+@mark_sync_test
+def test_ssr_enabled(ssr_hint, fake_socket_pair):
+    address = neo4j.Address(("127.0.0.1", 7687))
+    sockets = fake_socket_pair(
+        address,
+        packer_cls=Bolt3.PACKER_CLS,
+        unpacker_cls=Bolt3.UNPACKER_CLS,
+    )
+    meta = {"server": "Neo4j/4.3.4"}
+    if ssr_hint is not None:
+        meta["hints"] = {"ssr.enabled": ssr_hint}
+    sockets.server.send_message(b"\x70", meta)
+    sockets.server.send_message(b"\x70", {})
+    connection = Bolt3(
+        address, sockets.client, PoolConfig.max_connection_lifetime
+    )
+    assert connection.ssr_enabled is False
+    connection.hello()
+    assert connection.ssr_enabled is False
diff --git a/tests/unit/sync/io/test_class_bolt4x0.py b/tests/unit/sync/io/test_class_bolt4x0.py
index be3f4499..f9dfef4b 100644
--- a/tests/unit/sync/io/test_class_bolt4x0.py
+++ b/tests/unit/sync/io/test_class_bolt4x0.py
@@ -660,3 +660,25 @@ def on_success(metadata):
     connection.fetch_all()
 
     assert received_metadata == sent_metadata
+
+
+@pytest.mark.parametrize("ssr_hint", (True, False, None))
+@mark_sync_test
+def test_ssr_enabled(ssr_hint, fake_socket_pair):
+    address = neo4j.Address(("127.0.0.1", 7687))
+    sockets = fake_socket_pair(
+        address,
+        packer_cls=Bolt4x0.PACKER_CLS,
+        unpacker_cls=Bolt4x0.UNPACKER_CLS,
+    )
+    meta = {"server": "Neo4j/4.3.4"}
+    if ssr_hint is not None:
+        meta["hints"] = {"ssr.enabled": ssr_hint}
+    sockets.server.send_message(b"\x70", meta)
+    sockets.server.send_message(b"\x70", {})
+    connection = Bolt4x0(
+        address, sockets.client, PoolConfig.max_connection_lifetime
+    )
+    assert connection.ssr_enabled is False
+    connection.hello()
+    assert connection.ssr_enabled is False
diff --git a/tests/unit/sync/io/test_class_bolt4x1.py b/tests/unit/sync/io/test_class_bolt4x1.py
index 77b54513..219a9fda 100644
--- a/tests/unit/sync/io/test_class_bolt4x1.py
+++ b/tests/unit/sync/io/test_class_bolt4x1.py
@@ -682,3 +682,25 @@ def on_success(metadata):
     connection.fetch_all()
 
     assert received_metadata == sent_metadata
+
+
+@pytest.mark.parametrize("ssr_hint", (True, False, None))
+@mark_sync_test
+def test_ssr_enabled(ssr_hint, fake_socket_pair):
+    address = neo4j.Address(("127.0.0.1", 7687))
+    sockets = fake_socket_pair(
+        address,
+        packer_cls=Bolt4x1.PACKER_CLS,
+        unpacker_cls=Bolt4x1.UNPACKER_CLS,
+    )
+    meta = {"server": "Neo4j/4.3.4"}
+    if ssr_hint is not None:
+        meta["hints"] = {"ssr.enabled": ssr_hint}
+    sockets.server.send_message(b"\x70", meta)
+    sockets.server.send_message(b"\x70", {})
+    connection = Bolt4x1(
+        address, sockets.client, PoolConfig.max_connection_lifetime
+    )
+    assert connection.ssr_enabled is False
+    connection.hello()
+    assert connection.ssr_enabled is False
diff --git a/tests/unit/sync/io/test_class_bolt4x2.py b/tests/unit/sync/io/test_class_bolt4x2.py
index 65525f9e..944f7c28 100644
--- a/tests/unit/sync/io/test_class_bolt4x2.py
+++ b/tests/unit/sync/io/test_class_bolt4x2.py
@@ -682,3 +682,25 @@ def on_success(metadata):
     connection.fetch_all()
 
     assert received_metadata == sent_metadata
+
+
+@pytest.mark.parametrize("ssr_hint", (True, False, None))
+@mark_sync_test
+def test_ssr_enabled(ssr_hint, fake_socket_pair):
+    address = neo4j.Address(("127.0.0.1", 7687))
+    sockets = fake_socket_pair(
+        address,
+        packer_cls=Bolt4x2.PACKER_CLS,
+        unpacker_cls=Bolt4x2.UNPACKER_CLS,
+    )
+    meta = {"server": "Neo4j/4.3.4"}
+    if ssr_hint is not None:
+        meta["hints"] = {"ssr.enabled": ssr_hint}
+    sockets.server.send_message(b"\x70", meta)
+    sockets.server.send_message(b"\x70", {})
+    connection = Bolt4x2(
+        address, sockets.client, PoolConfig.max_connection_lifetime
+    )
+    assert connection.ssr_enabled is False
+    connection.hello()
+    assert connection.ssr_enabled is False
diff --git a/tests/unit/sync/io/test_class_bolt4x3.py b/tests/unit/sync/io/test_class_bolt4x3.py
index a4a61ba1..2e53fc42 100644
--- a/tests/unit/sync/io/test_class_bolt4x3.py
+++ b/tests/unit/sync/io/test_class_bolt4x3.py
@@ -711,3 +711,25 @@ def on_success(metadata):
     connection.fetch_all()
 
     assert received_metadata == sent_metadata
+
+
+@pytest.mark.parametrize("ssr_hint", (True, False, None))
+@mark_sync_test
+def test_ssr_enabled(ssr_hint, fake_socket_pair):
+    address = neo4j.Address(("127.0.0.1", 7687))
+    sockets = fake_socket_pair(
+        address,
+        packer_cls=Bolt4x3.PACKER_CLS,
+        unpacker_cls=Bolt4x3.UNPACKER_CLS,
+    )
+    meta = {"server": "Neo4j/4.3.4"}
+    if ssr_hint is not None:
+        meta["hints"] = {"ssr.enabled": ssr_hint}
+    sockets.server.send_message(b"\x70", meta)
+    sockets.server.send_message(b"\x70", {})
+    connection = Bolt4x3(
+        address, sockets.client, PoolConfig.max_connection_lifetime
+    )
+    assert connection.ssr_enabled is False
+    connection.hello()
+    assert connection.ssr_enabled is False
diff --git a/tests/unit/sync/io/test_class_bolt4x4.py b/tests/unit/sync/io/test_class_bolt4x4.py
index 692f64f8..9378e8bf 100644
--- a/tests/unit/sync/io/test_class_bolt4x4.py
+++ b/tests/unit/sync/io/test_class_bolt4x4.py
@@ -671,3 +671,25 @@ def on_success(metadata):
     connection.fetch_all()
 
     assert received_metadata == sent_metadata
+
+
+@pytest.mark.parametrize("ssr_hint", (True, False, None))
+@mark_sync_test
+def test_ssr_enabled(ssr_hint, fake_socket_pair):
+    address = neo4j.Address(("127.0.0.1", 7687))
+    sockets = fake_socket_pair(
+        address,
+        packer_cls=Bolt4x4.PACKER_CLS,
+        unpacker_cls=Bolt4x4.UNPACKER_CLS,
+    )
+    meta = {"server": "Neo4j/4.3.4"}
+    if ssr_hint is not None:
+        meta["hints"] = {"ssr.enabled": ssr_hint}
+    sockets.server.send_message(b"\x70", meta)
+    sockets.server.send_message(b"\x70", {})
+    connection = Bolt4x4(
+        address, sockets.client, PoolConfig.max_connection_lifetime
+    )
+    assert connection.ssr_enabled is False
+    connection.hello()
+    assert connection.ssr_enabled is False
diff --git a/tests/unit/sync/io/test_class_bolt5x0.py b/tests/unit/sync/io/test_class_bolt5x0.py
index 5bc3e2c2..2390112d 100644
--- a/tests/unit/sync/io/test_class_bolt5x0.py
+++ b/tests/unit/sync/io/test_class_bolt5x0.py
@@ -735,3 +735,25 @@ def on_success(metadata):
     connection.fetch_all()
 
     assert received_metadata == sent_metadata
+
+
+@pytest.mark.parametrize("ssr_hint", (True, False, None))
+@mark_sync_test
+def test_ssr_enabled(ssr_hint, fake_socket_pair):
+    address = neo4j.Address(("127.0.0.1", 7687))
+    sockets = fake_socket_pair(
+        address,
+        packer_cls=Bolt5x0.PACKER_CLS,
+        unpacker_cls=Bolt5x0.UNPACKER_CLS,
+    )
+    meta = {"server": "Neo4j/4.3.4"}
+    if ssr_hint is not None:
+        meta["hints"] = {"ssr.enabled": ssr_hint}
+    sockets.server.send_message(b"\x70", meta)
+    sockets.server.send_message(b"\x70", {})
+    connection = Bolt5x0(
+        address, sockets.client, PoolConfig.max_connection_lifetime
+    )
+    assert connection.ssr_enabled is False
+    connection.hello()
+    assert connection.ssr_enabled is False
diff --git a/tests/unit/sync/io/test_class_bolt5x1.py b/tests/unit/sync/io/test_class_bolt5x1.py
index 4376e39e..7b2804c4 100644
--- a/tests/unit/sync/io/test_class_bolt5x1.py
+++ b/tests/unit/sync/io/test_class_bolt5x1.py
@@ -789,3 +789,25 @@ def on_success(metadata):
     connection.fetch_all()
 
     assert received_metadata == sent_metadata
+
+
+@pytest.mark.parametrize("ssr_hint", (True, False, None))
+@mark_sync_test
+def test_ssr_enabled(ssr_hint, fake_socket_pair):
+    address = neo4j.Address(("127.0.0.1", 7687))
+    sockets = fake_socket_pair(
+        address,
+        packer_cls=Bolt5x1.PACKER_CLS,
+        unpacker_cls=Bolt5x1.UNPACKER_CLS,
+    )
+    meta = {"server": "Neo4j/4.3.4"}
+    if ssr_hint is not None:
+        meta["hints"] = {"ssr.enabled": ssr_hint}
+    sockets.server.send_message(b"\x70", meta)
+    sockets.server.send_message(b"\x70", {})
+    connection = Bolt5x1(
+        address, sockets.client, PoolConfig.max_connection_lifetime
+    )
+    assert connection.ssr_enabled is False
+    connection.hello()
+    assert connection.ssr_enabled is False
diff --git a/tests/unit/sync/io/test_class_bolt5x2.py b/tests/unit/sync/io/test_class_bolt5x2.py
index f2d0db48..165d1776 100644
--- a/tests/unit/sync/io/test_class_bolt5x2.py
+++ b/tests/unit/sync/io/test_class_bolt5x2.py
@@ -826,3 +826,25 @@ def on_success(metadata):
     connection.fetch_all()
 
     assert received_metadata == sent_metadata
+
+
+@pytest.mark.parametrize("ssr_hint", (True, False, None))
+@mark_sync_test
+def test_ssr_enabled(ssr_hint, fake_socket_pair):
+    address = neo4j.Address(("127.0.0.1", 7687))
+    sockets = fake_socket_pair(
+        address,
+        packer_cls=Bolt5x2.PACKER_CLS,
+        unpacker_cls=Bolt5x2.UNPACKER_CLS,
+    )
+    meta = {"server": "Neo4j/4.3.4"}
+    if ssr_hint is not None:
+        meta["hints"] = {"ssr.enabled": ssr_hint}
+    sockets.server.send_message(b"\x70", meta)
+    sockets.server.send_message(b"\x70", {})
+    connection = Bolt5x2(
+        address, sockets.client, PoolConfig.max_connection_lifetime
+    )
+    assert connection.ssr_enabled is False
+    connection.hello()
+    assert connection.ssr_enabled is False
diff --git a/tests/unit/sync/io/test_class_bolt5x3.py b/tests/unit/sync/io/test_class_bolt5x3.py
index fecd4d88..d0d17131 100644
--- a/tests/unit/sync/io/test_class_bolt5x3.py
+++ b/tests/unit/sync/io/test_class_bolt5x3.py
@@ -713,3 +713,25 @@ def on_success(metadata):
     connection.fetch_all()
 
     assert received_metadata == sent_metadata
+
+
+@pytest.mark.parametrize("ssr_hint", (True, False, None))
+@mark_sync_test
+def test_ssr_enabled(ssr_hint, fake_socket_pair):
+    address = neo4j.Address(("127.0.0.1", 7687))
+    sockets = fake_socket_pair(
+        address,
+        packer_cls=Bolt5x3.PACKER_CLS,
+        unpacker_cls=Bolt5x3.UNPACKER_CLS,
+    )
+    meta = {"server": "Neo4j/4.3.4"}
+    if ssr_hint is not None:
+        meta["hints"] = {"ssr.enabled": ssr_hint}
+    sockets.server.send_message(b"\x70", meta)
+    sockets.server.send_message(b"\x70", {})
+    connection = Bolt5x3(
+        address, sockets.client, PoolConfig.max_connection_lifetime
+    )
+    assert connection.ssr_enabled is False
+    connection.hello()
+    assert connection.ssr_enabled is False
diff --git a/tests/unit/sync/io/test_class_bolt5x4.py b/tests/unit/sync/io/test_class_bolt5x4.py
index 7740449d..ea938cc2 100644
--- a/tests/unit/sync/io/test_class_bolt5x4.py
+++ b/tests/unit/sync/io/test_class_bolt5x4.py
@@ -718,3 +718,25 @@ def on_success(metadata):
     connection.fetch_all()
 
     assert received_metadata == sent_metadata
+
+
+@pytest.mark.parametrize("ssr_hint", (True, False, None))
+@mark_sync_test
+def test_ssr_enabled(ssr_hint, fake_socket_pair):
+    address = neo4j.Address(("127.0.0.1", 7687))
+    sockets = fake_socket_pair(
+        address,
+        packer_cls=Bolt5x4.PACKER_CLS,
+        unpacker_cls=Bolt5x4.UNPACKER_CLS,
+    )
+    meta = {"server": "Neo4j/4.3.4"}
+    if ssr_hint is not None:
+        meta["hints"] = {"ssr.enabled": ssr_hint}
+    sockets.server.send_message(b"\x70", meta)
+    sockets.server.send_message(b"\x70", {})
+    connection = Bolt5x4(
+        address, sockets.client, PoolConfig.max_connection_lifetime
+    )
+    assert connection.ssr_enabled is False
+    connection.hello()
+    assert connection.ssr_enabled is False
diff --git a/tests/unit/sync/io/test_class_bolt5x5.py b/tests/unit/sync/io/test_class_bolt5x5.py
index c301de17..e5cc6e74 100644
--- a/tests/unit/sync/io/test_class_bolt5x5.py
+++ b/tests/unit/sync/io/test_class_bolt5x5.py
@@ -756,3 +756,25 @@ def extend_diag_record(r):
     }
 
     assert received_metadata == expected_metadata
+
+
+@pytest.mark.parametrize("ssr_hint", (True, False, None))
+@mark_sync_test
+def test_ssr_enabled(ssr_hint, fake_socket_pair):
+    address = neo4j.Address(("127.0.0.1", 7687))
+    sockets = fake_socket_pair(
+        address,
+        packer_cls=Bolt5x5.PACKER_CLS,
+        unpacker_cls=Bolt5x5.UNPACKER_CLS,
+    )
+    meta = {"server": "Neo4j/4.3.4"}
+    if ssr_hint is not None:
+        meta["hints"] = {"ssr.enabled": ssr_hint}
+    sockets.server.send_message(b"\x70", meta)
+    sockets.server.send_message(b"\x70", {})
+    connection = Bolt5x5(
+        address, sockets.client, PoolConfig.max_connection_lifetime
+    )
+    assert connection.ssr_enabled is False
+    connection.hello()
+    assert connection.ssr_enabled is False
diff --git a/tests/unit/sync/io/test_class_bolt5x6.py b/tests/unit/sync/io/test_class_bolt5x6.py
index 1f61f05e..a472ef5f 100644
--- a/tests/unit/sync/io/test_class_bolt5x6.py
+++ b/tests/unit/sync/io/test_class_bolt5x6.py
@@ -760,3 +760,25 @@ def extend_diag_record(r):
     }
 
     assert received_metadata == expected_metadata
+
+
+@pytest.mark.parametrize("ssr_hint", (True, False, None))
+@mark_sync_test
+def test_ssr_enabled(ssr_hint, fake_socket_pair):
+    address = neo4j.Address(("127.0.0.1", 7687))
+    sockets = fake_socket_pair(
+        address,
+        packer_cls=Bolt5x6.PACKER_CLS,
+        unpacker_cls=Bolt5x6.UNPACKER_CLS,
+    )
+    meta = {"server": "Neo4j/4.3.4"}
+    if ssr_hint is not None:
+        meta["hints"] = {"ssr.enabled": ssr_hint}
+    sockets.server.send_message(b"\x70", meta)
+    sockets.server.send_message(b"\x70", {})
+    connection = Bolt5x6(
+        address, sockets.client, PoolConfig.max_connection_lifetime
+    )
+    assert connection.ssr_enabled is False
+    connection.hello()
+    assert connection.ssr_enabled is False
diff --git a/tests/unit/sync/io/test_class_bolt5x7.py b/tests/unit/sync/io/test_class_bolt5x7.py
index 7d6523ee..95890c79 100644
--- a/tests/unit/sync/io/test_class_bolt5x7.py
+++ b/tests/unit/sync/io/test_class_bolt5x7.py
@@ -848,3 +848,25 @@ def _build_error_hierarchy_metadata(diag_records_metadata):
         if r is not ...:
             current_root["diagnostic_record"] = r
     return metadata
+
+
+@pytest.mark.parametrize("ssr_hint", (True, False, None))
+@mark_sync_test
+def test_ssr_enabled(ssr_hint, fake_socket_pair):
+    address = neo4j.Address(("127.0.0.1", 7687))
+    sockets = fake_socket_pair(
+        address,
+        packer_cls=Bolt5x7.PACKER_CLS,
+        unpacker_cls=Bolt5x7.UNPACKER_CLS,
+    )
+    meta = {"server": "Neo4j/4.3.4"}
+    if ssr_hint is not None:
+        meta["hints"] = {"ssr.enabled": ssr_hint}
+    sockets.server.send_message(b"\x70", meta)
+    sockets.server.send_message(b"\x70", {})
+    connection = Bolt5x7(
+        address, sockets.client, PoolConfig.max_connection_lifetime
+    )
+    assert connection.ssr_enabled is False
+    connection.hello()
+    assert connection.ssr_enabled is False
diff --git a/tests/unit/sync/io/test_class_bolt5x8.py b/tests/unit/sync/io/test_class_bolt5x8.py
index f172dc2a..25de8c2b 100644
--- a/tests/unit/sync/io/test_class_bolt5x8.py
+++ b/tests/unit/sync/io/test_class_bolt5x8.py
@@ -848,3 +848,25 @@ def _build_error_hierarchy_metadata(diag_records_metadata):
         if r is not ...:
             current_root["diagnostic_record"] = r
     return metadata
+
+
+@pytest.mark.parametrize("ssr_hint", (True, False, None))
+@mark_sync_test
+def test_ssr_enabled(ssr_hint, fake_socket_pair):
+    address = neo4j.Address(("127.0.0.1", 7687))
+    sockets = fake_socket_pair(
+        address,
+        packer_cls=Bolt5x8.PACKER_CLS,
+        unpacker_cls=Bolt5x8.UNPACKER_CLS,
+    )
+    meta = {"server": "Neo4j/4.3.4"}
+    if ssr_hint is not None:
+        meta["hints"] = {"ssr.enabled": ssr_hint}
+    sockets.server.send_message(b"\x70", meta)
+    sockets.server.send_message(b"\x70", {})
+    connection = Bolt5x8(
+        address, sockets.client, PoolConfig.max_connection_lifetime
+    )
+    assert connection.ssr_enabled is False
+    connection.hello()
+    assert connection.ssr_enabled is bool(ssr_hint)
diff --git a/tests/unit/sync/test_home_db_cache.py b/tests/unit/sync/test_home_db_cache.py
index 1c7aa977..cd656676 100644
--- a/tests/unit/sync/test_home_db_cache.py
+++ b/tests/unit/sync/test_home_db_cache.py
@@ -16,6 +16,7 @@
 
 from __future__ import annotations
 
+import math
 import time
 import typing as t
 from datetime import (
@@ -258,10 +259,15 @@ def get_default_cache():
     repetitions = 5
     scenario_timings = []
 
-    default_max_size = get_default_cache()._max_size
     # Test assumes that by default the driver uses a home db cache only limited
     # by its size.
-    assert default_max_size
+    default_cache = get_default_cache()
+    default_max_size = default_cache._max_size
+    assert isinstance(default_max_size, int)
+    # If ttl ever gets used, this test needs to be updated to also test pruning
+    # by TTL.
+    assert math.isinf(default_cache._ttl) and default_cache._ttl > 0
+
     for max_size, count in (
         # no pruning needed
        (default_max_size * 10, default_max_size * 10),
@@ -280,7 +286,8 @@ def get_default_cache():
         scenario_timings.append(sum(rep_timings) / len(rep_timings))
 
     # pruning shouldn't take more than 20 times the time of no pruning
-    # N.B., the pruning takes O(n * log(n)) where n is max_size. So to achieve
-    # this limit, either max_size needs to be sufficiently small or the pruning
-    # algorithm needs to be performant enough.
+    # N.B., the pruning takes O(n * log(n)) where n is max_size. By only
+    # pruning O(n * log(n)) elements, we get an amortized pruning overhead of
+    # O(1) (as long as max_size is small enough to be able to choose a positive
+    # pruning size).
     assert scenario_timings[1] <= 20 * scenario_timings[0]
diff --git a/tests/unit/sync/work/test_result.py b/tests/unit/sync/work/test_result.py
index 4f87da73..65249303 100644
--- a/tests/unit/sync/work/test_result.py
+++ b/tests/unit/sync/work/test_result.py
@@ -60,7 +60,10 @@
     Neo4jWarning,
 )
 
-from ...._async_compat import mark_sync_test
+from ...._async_compat import (
+    mark_sync_test,
+    wrap_async,
+)
 
 
 if t.TYPE_CHECKING:
@@ -1424,23 +1427,21 @@ def test_notification_logging(
     assert caplog.messages[0] == expected_message
 
 
-@pytest.mark.parametrize("cb", (True, False))
+@pytest.mark.parametrize(
+    "cb",
+    (True, False) if Util.is_async_code else (False,),
+)
 @pytest.mark.parametrize("resolved_db", (..., None, "resolved_db"))
 @mark_sync_test
 def test_on_database_callback(cb, resolved_db):
     cb_calls = []
 
-    if cb:
-
-        def db_callback(db):
-            nonlocal cb_calls
-            cb_calls.append(db)
-
-    else:
+    def db_callback(db):
+        nonlocal cb_calls
+        cb_calls.append(db)
 
-        def db_callback(db):
-            nonlocal cb_calls
-            cb_calls.append(db)
+    if cb:
+        db_callback = wrap_async(db_callback)
 
     run_meta = {}
     if resolved_db is not ...:
diff --git a/tests/unit/sync/work/test_session.py b/tests/unit/sync/work/test_session.py
index c081bb27..94543cde 100644
--- a/tests/unit/sync/work/test_session.py
+++ b/tests/unit/sync/work/test_session.py
@@ -435,12 +435,12 @@ def work(tx):
     assert call.kwargs["parameters"] == expected_params
 
 
-@pytest.mark.parametrize("db", (None, "adb")[:1])
-@pytest.mark.parametrize("routing", (True, False)[:1])
+@pytest.mark.parametrize("db", (None, "adb"))
+@pytest.mark.parametrize("routing", (True, False))
 # no home db resolution when connected to Neo4j 4.3 or earlier
-@pytest.mark.parametrize("home_db_gets_resolved", (True, False)[:1])
+@pytest.mark.parametrize("home_db_gets_resolved", (True, False))
 @pytest.mark.parametrize(
-    "additional_session_bookmarks", (None, ["session", "bookmarks"])[:1]
+    "additional_session_bookmarks", (None, ["session", "bookmarks"])
 )
 @mark_sync_test
 def test_with_bookmark_manager(
@@ -720,7 +720,10 @@ def work(_):
 @pytest.mark.parametrize("imp_user", (None, "imp_user"))
 @pytest.mark.parametrize(
     "auth",
-    (None, Auth(scheme="magic-auth", principal=None, credentials="tada")),
+    (
+        None,
+        Auth(scheme="magic-auth", principal=None, credentials="tada"),
+    ),
 )
 @mark_sync_test
 def test_uses_home_db_cache_when_expected(
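The test_uses_home_db_cache_when_expected parametrization over imp_user and auth (here and in its async twin above) suggests that cached home database resolutions must be scoped to the effective user identity, since different users can have different home databases. A hedged sketch of such a cache key; the actual derivation in home_db_cache.py may well differ, and this only illustrates that distinct identities must not share cached entries:

    def home_db_cache_key(imp_user, auth):
        if imp_user is not None:
            # Impersonation fixes the effective user regardless of auth.
            return ("imp_user", imp_user)
        if auth is not None:
            # Session auth: key on scheme and principal rather than secrets.
            return ("auth", auth.scheme, auth.principal)
        # Fall back to the driver-level auth (a single shared entry).
        return ("driver",)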
From 6b309729bda247ca37ebf2ef104ece7c04f2ba2a Mon Sep 17 00:00:00 2001
From: Robsdedude
Date: Mon, 13 Jan 2025 13:56:49 +0100
Subject: [PATCH 22/26] Resolution fallback counting towards connection acquisition timeout

---
 src/neo4j/_async/io/_pool.py       | 4 +++-
 src/neo4j/_async/work/workspace.py | 3 ++-
 src/neo4j/_sync/io/_pool.py        | 4 +++-
 src/neo4j/_sync/work/workspace.py  | 3 ++-
 4 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py
index be697ddb..73914f18 100644
--- a/src/neo4j/_async/io/_pool.py
+++ b/src/neo4j/_async/io/_pool.py
@@ -1137,7 +1137,9 @@ async def acquire(
         if access_mode not in {WRITE_ACCESS, READ_ACCESS}:
             # TODO: 6.0 - change this to be a ValueError
             raise ClientError(f"Non valid 'access_mode'; {access_mode}")
-        if not timeout:
+        if (
+            isinstance(timeout, Deadline) and not timeout.original_timeout
+        ) or not timeout:
             # TODO: 6.0 - change this to be a ValueError
             raise ClientError(
                 f"'timeout' must be a float larger than 0; {timeout}"
diff --git a/src/neo4j/_async/work/workspace.py b/src/neo4j/_async/work/workspace.py
index e88c4a7f..78d0be01 100644
--- a/src/neo4j/_async/work/workspace.py
+++ b/src/neo4j/_async/work/workspace.py
@@ -22,6 +22,7 @@
 from ..._async_compat.util import AsyncUtil
 from ..._auth_management import to_auth_dict
 from ..._conf import WorkspaceConfig
+from ..._deadline import Deadline
 from ..._meta import (
     deprecation_warn,
     unclosed_resource_warn,
@@ -191,7 +192,7 @@ async def _connect(self, access_mode, auth=None, **acquire_kwargs) -> None:
         )
         acquire_kwargs_ = {
             "access_mode": access_mode,
-            "timeout": acquisition_timeout,
+            "timeout": Deadline(acquisition_timeout),
             "database": target_db,
             "bookmarks": await self._get_bookmarks(),
             "auth": acquire_auth,
diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py
index c04c1165..a58cef93 100644
--- a/src/neo4j/_sync/io/_pool.py
+++ b/src/neo4j/_sync/io/_pool.py
@@ -1134,7 +1134,9 @@ def acquire(
         if access_mode not in {WRITE_ACCESS, READ_ACCESS}:
             # TODO: 6.0 - change this to be a ValueError
             raise ClientError(f"Non valid 'access_mode'; {access_mode}")
-        if not timeout:
+        if (
+            isinstance(timeout, Deadline) and not timeout.original_timeout
+        ) or not timeout:
             # TODO: 6.0 - change this to be a ValueError
             raise ClientError(
                 f"'timeout' must be a float larger than 0; {timeout}"
diff --git a/src/neo4j/_sync/work/workspace.py b/src/neo4j/_sync/work/workspace.py
index 39955001..6a32bea2 100644
--- a/src/neo4j/_sync/work/workspace.py
+++ b/src/neo4j/_sync/work/workspace.py
@@ -22,6 +22,7 @@
 from ..._async_compat.util import Util
 from ..._auth_management import to_auth_dict
 from ..._conf import WorkspaceConfig
+from ..._deadline import Deadline
 from ..._meta import (
     deprecation_warn,
     unclosed_resource_warn,
@@ -188,7 +189,7 @@ def _connect(self, access_mode, auth=None, **acquire_kwargs) -> None:
         )
         acquire_kwargs_ = {
             "access_mode": access_mode,
-            "timeout": acquisition_timeout,
+            "timeout": Deadline(acquisition_timeout),
             "database": target_db,
             "bookmarks": self._get_bookmarks(),
             "auth": acquire_auth,
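The idea behind this change (later reverted in PATCH 24/26): wrapping the configured acquisition timeout in a Deadline at the start of _connect makes home database resolution and pool acquisition consume a single shared time budget instead of each step restarting the clock. Only Deadline(...) and .original_timeout appear in the patch; the rest of this sketch is an assumed shape, not the driver's actual _deadline module:

    import time


    class Deadline:
        # Sketch: a countdown started once and handed through all steps.
        def __init__(self, timeout):
            self.original_timeout = timeout
            self._expiry = (
                None if timeout is None else time.monotonic() + timeout
            )

        def remaining(self):
            # Budget left for the next step (resolution, acquisition, ...).
            if self._expiry is None:
                return None
            return max(0.0, self._expiry - time.monotonic())

This also explains the companion change in acquire: a Deadline whose original_timeout is falsy is rejected just like the old "if not timeout:" check rejected a zero timeout.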
From 439422129b769bb0fb2a15a4b8d25e6e8cf51d5a Mon Sep 17 00:00:00 2001
From: MaxAake <61233757+MaxAake@users.noreply.github.com>
Date: Mon, 13 Jan 2025 14:50:11 +0100
Subject: [PATCH 23/26] testkit backend connection lifetime support

---
 testkitbackend/_async/requests.py | 1 +
 testkitbackend/test_config.json   | 1 +
 2 files changed, 2 insertions(+)

diff --git a/testkitbackend/_async/requests.py b/testkitbackend/_async/requests.py
index 08f00917..f08bb30d 100644
--- a/testkitbackend/_async/requests.py
+++ b/testkitbackend/_async/requests.py
@@ -190,6 +190,7 @@ async def new_driver(backend, data):
         ("maxTxRetryTimeMs", "max_transaction_retry_time"),
         ("connectionAcquisitionTimeoutMs", "connection_acquisition_timeout"),
         ("livenessCheckTimeoutMs", "liveness_check_timeout"),
+        ("maxConnectionLifetimeMs", "max_connection_lifetime"),
     ):
         if data.get(timeout_testkit) is not None:
             kwargs[timeout_driver] = data[timeout_testkit] / 1000
diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json
index 38d2b38c..580cad89 100644
--- a/testkitbackend/test_config.json
+++ b/testkitbackend/test_config.json
@@ -27,6 +27,7 @@
     "Feature:API:Driver.VerifyAuthentication": true,
     "Feature:API:Driver.VerifyConnectivity": true,
     "Feature:API:Driver.SupportsSessionAuth": true,
+    "Feature:API:Driver:MaxConnectionLifetime": true,
     "Feature:API:Driver:NotificationsConfig": true,
     "Feature:API:Liveness.Check": true,
     "Feature:API:Result.List": true,

From 89f3094b3a8c618a1c25d5d6040be3f3ffe377fa Mon Sep 17 00:00:00 2001
From: Robsdedude
Date: Mon, 13 Jan 2025 16:11:35 +0100
Subject: [PATCH 24/26] Revert "Resolution fallback counting towards connection acquisition timeout"

This reverts commit 6b309729bda247ca37ebf2ef104ece7c04f2ba2a.
---
 src/neo4j/_async/io/_pool.py       | 4 +---
 src/neo4j/_async/work/workspace.py | 3 +--
 src/neo4j/_sync/io/_pool.py        | 4 +---
 src/neo4j/_sync/work/workspace.py  | 3 +--
 4 files changed, 4 insertions(+), 10 deletions(-)

diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py
index 73914f18..be697ddb 100644
--- a/src/neo4j/_async/io/_pool.py
+++ b/src/neo4j/_async/io/_pool.py
@@ -1137,9 +1137,7 @@ async def acquire(
         if access_mode not in {WRITE_ACCESS, READ_ACCESS}:
             # TODO: 6.0 - change this to be a ValueError
             raise ClientError(f"Non valid 'access_mode'; {access_mode}")
-        if (
-            isinstance(timeout, Deadline) and not timeout.original_timeout
-        ) or not timeout:
+        if not timeout:
             # TODO: 6.0 - change this to be a ValueError
             raise ClientError(
                 f"'timeout' must be a float larger than 0; {timeout}"
diff --git a/src/neo4j/_async/work/workspace.py b/src/neo4j/_async/work/workspace.py
index 78d0be01..e88c4a7f 100644
--- a/src/neo4j/_async/work/workspace.py
+++ b/src/neo4j/_async/work/workspace.py
@@ -22,7 +22,6 @@
 from ..._async_compat.util import AsyncUtil
 from ..._auth_management import to_auth_dict
 from ..._conf import WorkspaceConfig
-from ..._deadline import Deadline
 from ..._meta import (
     deprecation_warn,
     unclosed_resource_warn,
@@ -192,7 +191,7 @@ async def _connect(self, access_mode, auth=None, **acquire_kwargs) -> None:
         )
         acquire_kwargs_ = {
             "access_mode": access_mode,
-            "timeout": Deadline(acquisition_timeout),
+            "timeout": acquisition_timeout,
             "database": target_db,
             "bookmarks": await self._get_bookmarks(),
             "auth": acquire_auth,
diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py
index a58cef93..c04c1165 100644
--- a/src/neo4j/_sync/io/_pool.py
+++ b/src/neo4j/_sync/io/_pool.py
@@ -1134,9 +1134,7 @@ def acquire(
         if access_mode not in {WRITE_ACCESS, READ_ACCESS}:
             # TODO: 6.0 - change this to be a ValueError
             raise ClientError(f"Non valid 'access_mode'; {access_mode}")
-        if (
-            isinstance(timeout, Deadline) and not timeout.original_timeout
-        ) or not timeout:
+        if not timeout:
             # TODO: 6.0 - change this to be a ValueError
             raise ClientError(
                 f"'timeout' must be a float larger than 0; {timeout}"
diff --git a/src/neo4j/_sync/work/workspace.py b/src/neo4j/_sync/work/workspace.py
index 6a32bea2..39955001 100644
--- a/src/neo4j/_sync/work/workspace.py
+++ b/src/neo4j/_sync/work/workspace.py
@@ -22,7 +22,6 @@
 from ..._async_compat.util import Util
 from ..._auth_management import to_auth_dict
 from ..._conf import WorkspaceConfig
-from ..._deadline import Deadline
 from ..._meta import (
     deprecation_warn,
     unclosed_resource_warn,
@@ -189,7 +188,7 @@ def _connect(self, access_mode, auth=None, **acquire_kwargs) -> None:
         )
         acquire_kwargs_ = {
             "access_mode": access_mode,
-            "timeout": Deadline(acquisition_timeout),
+            "timeout": acquisition_timeout,
             "database": target_db,
             "bookmarks": self._get_bookmarks(),
             "auth": acquire_auth,

From b34c3d7ea2f82ebd0223536bb89d3aec0e1d01a7 Mon Sep 17 00:00:00 2001
From: Robsdedude
Date: Mon, 13 Jan 2025 16:11:46 +0100
Subject: [PATCH 25/26] TestKit: skip test for unified acquisition timeout on db cache fallback

---
 testkitbackend/test_config.json | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json
index 580cad89..0fe66cc5 100644
--- a/testkitbackend/test_config.json
+++ b/testkitbackend/test_config.json
@@ -15,7 +15,9 @@
     "'neo4j.datatypes.test_temporal_types.TestDataTypes.test_date_time_cypher_created_tz_id'":
       "test_subtest_skips.tz_id",
     "stub\\.routing\\.test_routing_v[0-9x]+\\.RoutingV[0-9x]+\\.test_should_drop_connections_failing_liveness_check":
-      "Liveness check error handling is not (yet) unified: https://github.com/neo-technology/drivers-adr/pull/83"
+      "Liveness check error handling is not (yet) unified: https://github.com/neo-technology/drivers-adr/pull/83",
+    "'stub.homedb.test_homedb.TestHomeDbMixedCluster.test_connection_acquisition_timeout_during_fallback'":
+      "TODO: 6.0 - pending unification: connection acquisition timeout should count towards the total time spent waiting for a connection (including routing, home db resolution, ...)"
   },
   "features": {
     "Feature:API:BookmarkManager": true,
From 0f4a7471930c163ea8602c6b6ca8e8c348b24d75 Mon Sep 17 00:00:00 2001
From: Robsdedude
Date: Tue, 14 Jan 2025 10:04:23 +0100
Subject: [PATCH 26/26] Optimization: don't use lock for home db cache SSR race check

Only checking the acquired connection for SSR support is sufficient for
correctness and saves checking all connections in the pool (i.e., saves
taking the lock on the pool).
---
 src/neo4j/_async/work/workspace.py | 8 +++-----
 src/neo4j/_sync/work/workspace.py  | 8 +++-----
 2 files changed, 6 insertions(+), 10 deletions(-)

diff --git a/src/neo4j/_async/work/workspace.py b/src/neo4j/_async/work/workspace.py
index e88c4a7f..6f0a08c7 100644
--- a/src/neo4j/_async/work/workspace.py
+++ b/src/neo4j/_async/work/workspace.py
@@ -203,12 +203,10 @@ async def _connect(self, access_mode, auth=None, **acquire_kwargs) -> None:
         if (
             target_db.guessed
             and not self._pinned_database
-            and (
-                not self._pool.ssr_enabled or not self._connection.ssr_enabled
-            )
+            and not self._connection.ssr_enabled
         ):
-            # race condition: in the meantime the pool added a connection
-            # which does not support SSR.
+            # race condition: we have now created a connection which does not
+            # support SSR.
             # => we need to fall back to explicit home database resolution
             log.debug(
                 "[#0000] _: detected ssr support race; "
diff --git a/src/neo4j/_sync/work/workspace.py b/src/neo4j/_sync/work/workspace.py
index 39955001..a85bdf8d 100644
--- a/src/neo4j/_sync/work/workspace.py
+++ b/src/neo4j/_sync/work/workspace.py
@@ -200,12 +200,10 @@ def _connect(self, access_mode, auth=None, **acquire_kwargs) -> None:
         if (
             target_db.guessed
             and not self._pinned_database
-            and (
-                not self._pool.ssr_enabled or not self._connection.ssr_enabled
-            )
+            and not self._connection.ssr_enabled
         ):
-            # race condition: in the meantime the pool added a connection
-            # which does not support SSR.
+            # race condition: we have now created a connection which does not
+            # support SSR.
             # => we need to fall back to explicit home database resolution
             log.debug(
                 "[#0000] _: detected ssr support race; "