diff --git a/docs/source/api.rst b/docs/source/api.rst index b5c133ea3..760cb48de 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -403,6 +403,7 @@ Additional configuration can be provided via the :class:`neo4j.Driver` construct + :ref:`trust-ref` + :ref:`ssl-context-ref` + :ref:`trusted-certificates-ref` ++ :ref:`client-certificate-ref` + :ref:`user-agent-ref` + :ref:`driver-notifications-min-severity-ref` + :ref:`driver-notifications-disabled-categories-ref` @@ -573,7 +574,8 @@ Specify how to determine the authenticity of encryption certificates provided by This setting is only available for URI schemes ``bolt://`` and ``neo4j://`` (:ref:`uri-ref`). -This setting does not have any effect if ``encrypted`` is set to ``False``. +This setting does not have any effect if ``encrypted`` is set to ``False`` or a +custom ``ssl_context`` is configured. :Type: ``neo4j.TRUST_SYSTEM_CA_SIGNED_CERTIFICATES``, ``neo4j.TRUST_ALL_CERTIFICATES`` @@ -605,7 +607,7 @@ Specify a custom SSL context to use for wrapping connections. This setting is only available for URI schemes ``bolt://`` and ``neo4j://`` (:ref:`uri-ref`). -If given, ``encrypted`` and ``trusted_certificates`` have no effect. +If given, ``encrypted``, ``trusted_certificates``, and ``client_certificate`` have no effect. .. warning:: This option may compromise your application's security if used improperly. @@ -632,13 +634,37 @@ custom ``ssl_context`` is configured. :Type: :class:`.TrustSystemCAs`, :class:`.TrustAll`, or :class:`.TrustCustomCAs` :Default: :const:`neo4j.TrustSystemCAs()` +.. versionadded:: 5.0 + .. autoclass:: neo4j.TrustSystemCAs .. autoclass:: neo4j.TrustAll .. autoclass:: neo4j.TrustCustomCAs -.. versionadded:: 5.0 + +.. _client-certificate-ref: + +``client_certificate`` +---------------------- +Specify a client certificate or certificate provider for mutual TLS (mTLS) authentication. + +This setting does not have any effect if ``encrypted`` is set to ``False`` +(and the URI scheme is ``bolt://`` or ``neo4j://``) or a custom ``ssl_context`` is configured. + +**This is a preview** (see :ref:`filter-warnings-ref`). +It might be changed without following the deprecation policy. +See also +https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + +:Type: :class:`.ClientCertificate`, :class:`.ClientCertificateProvider` or :data:`None`. +:Default: :data:`None` + +.. versionadded:: 5.19 + +.. autoclass:: neo4j.auth_management.ClientCertificate + +.. autoclass:: neo4j.auth_management.ClientCertificateProvider .. _user-agent-ref: diff --git a/docs/source/async_api.rst b/docs/source/async_api.rst index ba06e8c25..8c666caae 100644 --- a/docs/source/async_api.rst +++ b/docs/source/async_api.rst @@ -381,7 +381,8 @@ Async Driver Configuration driver accepts * a sync as well as an async custom resolver function (see :ref:`async-resolver-ref`) - * as sync as well as an async auth token manager (see :class:`.AsyncAuthManager`). + * a sync as well as an async auth token manager (see :class:`.AsyncAuthManager`). + * an async client certificate provider (see :ref:`async-client-certificate-ref`). .. _async-resolver-ref: @@ -436,6 +437,28 @@ For example: :Default: :data:`None` +.. _async-client-certificate-ref: + +``client_certificate`` +---------------------- +Specify a client certificate or certificate provider for mutual TLS (mTLS) authentication. + +This setting does not have any effect if ``encrypted`` is set to ``False`` +(and the URI scheme is ``bolt://`` or ``neo4j://``) or a custom ``ssl_context`` is configured. + +**This is a preview** (see :ref:`filter-warnings-ref`). +It might be changed without following the deprecation policy. +See also +https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + +:Type: :class:`.ClientCertificate`, :class:`.AsyncClientCertificateProvider` or :data:`None`. +:Default: :data:`None` + +.. versionadded:: 5.19 + +.. autoclass:: neo4j.auth_management.AsyncClientCertificateProvider + + Driver Object Lifetime ====================== diff --git a/src/neo4j/__init__.py b/src/neo4j/__init__.py index 70c6bcc0c..6cc645003 100644 --- a/src/neo4j/__init__.py +++ b/src/neo4j/__init__.py @@ -37,7 +37,6 @@ ) from ._conf import ( Config as _Config, - PoolConfig as _PoolConfig, SessionConfig as _SessionConfig, TrustAll, TrustCustomCAs, @@ -53,6 +52,7 @@ PreviewWarning, version as __version__, ) +from ._sync.config import PoolConfig as _PoolConfig from ._sync.driver import ( BoltDriver, Driver, diff --git a/src/neo4j/_api.py b/src/neo4j/_api.py index 97dfc25da..1582230d0 100644 --- a/src/neo4j/_api.py +++ b/src/neo4j/_api.py @@ -48,11 +48,11 @@ class NotificationMinimumSeverity(str, Enum): >>> NotificationMinimumSeverity.INFORMATION == "INFORMATION" True - .. versionadded:: 5.7 - .. seealso:: driver config :ref:`driver-notifications-min-severity-ref`, session config :ref:`session-notifications-min-severity-ref` + + .. versionadded:: 5.7 """ OFF = "OFF" @@ -111,9 +111,9 @@ class NotificationSeverity(str, Enum): # or severity_level == "UNKNOWN" log.debug("%r", notification) - .. versionadded:: 5.7 - .. seealso:: :attr:`SummaryNotification.severity_level` + + .. versionadded:: 5.7 """ WARNING = "WARNING" @@ -137,14 +137,14 @@ class NotificationDisabledCategory(str, Enum): >>> NotificationDisabledCategory.DEPRECATION == "DEPRECATION" True + .. seealso:: + driver config :ref:`driver-notifications-disabled-categories-ref`, + session config :ref:`session-notifications-disabled-categories-ref` + .. versionadded:: 5.7 .. versionchanged:: 5.14 Added categories :attr:`.SECURITY` and :attr:`.TOPOLOGY`. - - .. seealso:: - driver config :ref:`driver-notifications-disabled-categories-ref`, - session config :ref:`session-notifications-disabled-categories-ref` """ HINT = "HINT" @@ -188,12 +188,12 @@ class NotificationCategory(str, Enum): >>> NotificationCategory.UNKNOWN == "UNKNOWN" True + .. seealso:: :attr:`SummaryNotification.category` + .. versionadded:: 5.7 .. versionchanged:: 5.14 Added categories :attr:`.SECURITY` and :attr:`.TOPOLOGY`. - - .. seealso:: :attr:`SummaryNotification.category` """ HINT = "HINT" diff --git a/src/neo4j/_async/auth_management.py b/src/neo4j/_async/auth_management.py index 492e5eec2..baab82172 100644 --- a/src/neo4j/_async/auth_management.py +++ b/src/neo4j/_async/auth_management.py @@ -19,12 +19,18 @@ import typing as t from logging import getLogger -from .._async_compat.concurrency import AsyncLock +from .._async_compat.concurrency import ( + AsyncCooperativeLock, + AsyncLock, +) from .._auth_management import ( AsyncAuthManager, + AsyncClientCertificateProvider, + ClientCertificate, expiring_auth_has_expired, ExpiringAuth, ) +from .._meta import preview if t.TYPE_CHECKING: @@ -285,3 +291,127 @@ async def auth_provider(): "Neo.ClientError.Security.Unauthorized", )) return AsyncNeo4jAuthTokenManager(provider, handled_codes) + + +class _AsyncStaticClientCertificateProvider(AsyncClientCertificateProvider): + _cert: t.Optional[ClientCertificate] + + def __init__(self, cert: ClientCertificate) -> None: + self._cert = cert + + async def get_certificate(self) -> t.Optional[ClientCertificate]: + cert, self._cert = self._cert, None + return cert + + +@preview("Mutual TLS is a preview feature.") +class AsyncRotatingClientCertificateProvider(AsyncClientCertificateProvider): + """ + Implementation of a certificate provider that can rotate certificates. + + The provider will make the driver use the initial certificate for all + connections until the certificate is updated using the + :meth:`update_certificate` method. + From that point on, the new certificate will be used for all new + connections until :meth:`update_certificate` is called again and so on. + + **This is a preview** (see :ref:`filter-warnings-ref`). + It might be changed without following the deprecation policy. + See also + https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + + Example:: + + from neo4j import AsyncGraphDatabase + from neo4j.auth_management import ( + ClientCertificate, + AsyncClientCertificateProviders, + ) + + + provider = AsyncClientCertificateProviders.rotating( + ClientCertificate( + certfile="path/to/certfile.pem", + keyfile="path/to/keyfile.pem", + password=lambda: "super_secret_password" + ) + ) + driver = AsyncGraphDatabase.driver( + # secure driver must be configured for client certificate + # to be used: (...+s[sc] scheme or encrypted=True) + "neo4j+s://example.com:7687", + # auth still required as before, unless server is configured to not + # use authentication + auth=("neo4j", "password"), + client_certificate=provider + ) + + # do work with the driver, until the certificate needs to be rotated + ... + + await provider.update_certificate( + ClientCertificate( + certfile="path/to/new/certfile.pem", + keyfile="path/to/new/keyfile.pem", + password=lambda: "new_super_secret_password" + ) + ) + + # do more work with the driver, until the certificate needs to be + # rotated again + ... + + :param initial_cert: The certificate to use initially. + + .. versionadded:: 5.19 + + """ + def __init__(self, initial_cert: ClientCertificate) -> None: + self._cert: t.Optional[ClientCertificate] = initial_cert + self._lock = AsyncCooperativeLock() + + async def get_certificate(self) -> t.Optional[ClientCertificate]: + async with self._lock: + cert, self._cert = self._cert, None + return cert + + async def update_certificate(self, cert: ClientCertificate) -> None: + """ + Update the certificate to use for new connections. + """ + async with self._lock: + self._cert = cert + + +class AsyncClientCertificateProviders: + """A collection of :class:`.AsyncClientCertificateProvider` factories. + + **This is a preview** (see :ref:`filter-warnings-ref`). + It might be changed without following the deprecation policy. + See also + https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + + .. versionadded:: 5.19 + """ + @staticmethod + @preview("Mutual TLS is a preview feature.") + def static(cert: ClientCertificate) -> AsyncClientCertificateProvider: + """ + Create a static client certificate provider. + + The provider simply makes the driver use the given certificate for all + connections. + """ + return _AsyncStaticClientCertificateProvider(cert) + + @staticmethod + @preview("Mutual TLS is a preview feature.") + def rotating( + initial_cert: ClientCertificate + ) -> AsyncRotatingClientCertificateProvider: + """ + Create certificate provider that allows for rotating certificates. + + .. seealso:: :class:`.AsyncRotatingClientCertificateProvider` + """ + return AsyncRotatingClientCertificateProvider(initial_cert) diff --git a/src/neo4j/_async/config.py b/src/neo4j/_async/config.py new file mode 100644 index 000000000..1366f54cd --- /dev/null +++ b/src/neo4j/_async/config.py @@ -0,0 +1,187 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +import typing as t + +from .._async_compat.concurrency import AsyncLock +from .._conf import ( + _trust_to_trusted_certificates, + Config, + DeprecatedAlternative, + TrustAll, + TrustCustomCAs, + TrustSystemCAs, +) + + +if t.TYPE_CHECKING: + import ssl + + from .._auth_management import ClientCertificate + + +class AsyncPoolConfig(Config): + """ Connection pool configuration. + """ + + #: Max Connection Lifetime + max_connection_lifetime = 3600 # seconds + # The maximum duration the driver will keep a connection for before being removed from the pool. + + #: Timeout after which idle connections will be checked for liveness + #: before returned from the pool. + liveness_check_timeout = None + + #: Max Connection Pool Size + max_connection_pool_size = 100 + # The maximum total number of connections allowed, per host (i.e. cluster nodes), to be managed by the connection pool. + + #: Connection Timeout + connection_timeout = 30.0 # seconds + # The maximum amount of time to wait for a TCP connection to be established. + + #: Trust + trust = DeprecatedAlternative( + "trusted_certificates", _trust_to_trusted_certificates + ) + # Specify how to determine the authenticity of encryption certificates provided by the Neo4j instance on connection. + + #: Custom Resolver + resolver = None + # Custom resolver function, returning list of resolved addresses. + + #: Encrypted + encrypted = False + # Specify whether to use an encrypted connection between the driver and server. + + #: SSL Certificates to Trust + trusted_certificates = TrustSystemCAs() + # Specify how to determine the authenticity of encryption certificates + # provided by the Neo4j instance on connection. + # * `neo4j.TrustSystemCAs()`: Use system trust store. (default) + # * `neo4j.TrustAll()`: Trust any certificate. + # * `neo4j.TrustCustomCAs("", ...)`: + # Trust the specified certificate(s). + + #: Certificate to use for mTLS as 2nd authentication factor. + client_certificate = None + + #: Custom SSL context to use for wrapping sockets + ssl_context = None + # Use any custom SSL context to wrap sockets. + # Overwrites `trusted_certificates` and `encrypted`. + # The use of this option is strongly discouraged. + + #: User Agent (Python Driver Specific) + user_agent = None + # Specify the client agent name. + + #: Socket Keep Alive (Python and .NET Driver Specific) + keep_alive = True + # Specify whether TCP keep-alive should be enabled. + + #: Authentication provider + auth = None + + #: Lowest notification severity for the server to return + notifications_min_severity = None + + #: List of notification categories for the server to ignore + notifications_disabled_categories = None + + #: Opt-Out of telemetry collection + telemetry_disabled = False + + _ssl_context_cache: t.Optional[ssl.SSLContext] + _ssl_context_cache_lock: AsyncLock + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._ssl_context_cache = None + self._ssl_context_cache_lock = AsyncLock() + + async def get_ssl_context(self) -> t.Optional[ssl.SSLContext]: + if self.ssl_context is not None: + return self.ssl_context + + if not self.encrypted: + return None + + client_cert: t.Optional[ClientCertificate] = None + + # try to serve the cached ssl context + async with self._ssl_context_cache_lock: + if self._ssl_context_cache is not None: + if self.client_certificate is None: + return self._ssl_context_cache + client_cert = await self.client_certificate.get_certificate() + if client_cert is None: + return self._ssl_context_cache + elif self.client_certificate is not None: + client_cert = await self.client_certificate.get_certificate() + + import ssl + + # SSL stands for Secure Sockets Layer and was originally created by + # Netscape. + # SSLv2 and SSLv3 are the 2 versions of this protocol (SSLv1 was + # never publicly released). + # After SSLv3, SSL was renamed to TLS. + # TLS stands for Transport Layer Security and started with TLSv1.0 + # which is an upgraded version of SSLv3. + # SSLv2 - (Disabled) + # SSLv3 - (Disabled) + # TLS 1.0 - Released in 1999, published as RFC 2246. (Disabled) + # TLS 1.1 - Released in 2006, published as RFC 4346. (Disabled) + # TLS 1.2 - Released in 2008, published as RFC 5246. + # https://docs.python.org/3.7/library/ssl.html#ssl.PROTOCOL_TLS_CLIENT + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + + # For recommended security options see + # https://docs.python.org/3.10/library/ssl.html#protocol-versions + ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2 + + if isinstance(self.trusted_certificates, TrustAll): + # trust any certificate + ssl_context.check_hostname = False + # https://docs.python.org/3.7/library/ssl.html#ssl.CERT_NONE + ssl_context.verify_mode = ssl.CERT_NONE + elif isinstance(self.trusted_certificates, TrustCustomCAs): + # trust the specified certificate(s) + ssl_context.check_hostname = True + ssl_context.verify_mode = ssl.CERT_REQUIRED + for cert in self.trusted_certificates.certs: + ssl_context.load_verify_locations(cert) + else: + # default + # trust system CA certificates + ssl_context.check_hostname = True + ssl_context.verify_mode = ssl.CERT_REQUIRED + # Must be load_default_certs, not set_default_verify_paths to + # work on Windows with system CAs. + ssl_context.load_default_certs() + + if client_cert is not None: + ssl_context.load_cert_chain( + client_cert.certfile, + keyfile=client_cert.keyfile, + password=client_cert.password, + ) + + self._ssl_context_cache = ssl_context + return ssl_context diff --git a/src/neo4j/_async/driver.py b/src/neo4j/_async/driver.py index 57312ac9a..f59a2d79e 100644 --- a/src/neo4j/_async/driver.py +++ b/src/neo4j/_async/driver.py @@ -37,7 +37,6 @@ from .._conf import ( Config, ConfigurationError, - PoolConfig, SessionConfig, TrustAll, TrustStore, @@ -46,6 +45,7 @@ from .._meta import ( deprecation_warn, experimental_warn, + preview_warn, unclosed_resource_warn, ) from .._work import ( @@ -80,13 +80,17 @@ from ..auth_management import ( AsyncAuthManager, AsyncAuthManagers, + AsyncClientCertificateProvider, + ClientCertificate, ) from ..exceptions import Neo4jError +from .auth_management import _AsyncStaticClientCertificateProvider from .bookmark_manager import ( AsyncNeo4jBookmarkManager, TBmConsumer as _TBmConsumer, TBmSupplier as _TBmSupplier, ) +from .config import AsyncPoolConfig from .work import ( AsyncManagedTransaction, AsyncResult, @@ -141,7 +145,10 @@ def driver( ] = ..., encrypted: bool = ..., trusted_certificates: TrustStore = ..., - ssl_context: ssl.SSLContext = ..., + client_certificate: t.Union[ + ClientCertificate, AsyncClientCertificateProvider, None + ] = ..., + ssl_context: t.Optional[ssl.SSLContext] = ..., user_agent: str = ..., keep_alive: bool = ..., notifications_min_severity: t.Optional[ @@ -194,6 +201,16 @@ def driver( auth = AsyncAuthManagers.static(auth) config["auth"] = auth + client_certificate = config.get("client_certificate") + if isinstance(client_certificate, ClientCertificate): + # using internal class until public factory is GA: + # AsyncClientCertificateProviders.static + config["client_certificate"] = \ + _AsyncStaticClientCertificateProvider(client_certificate) + if client_certificate is not None: + preview_warn("Mutual TLS is a preview feature.", + stack_level=2) + # TODO: 6.0 - remove "trust" config option if "trust" in config.keys(): if config["trust"] not in ( @@ -221,11 +238,12 @@ def driver( ) ) - if (security_type in [SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, SECURITY_TYPE_SECURE] - and ("encrypted" in config.keys() - or "trust" in config.keys() - or "trusted_certificates" in config.keys() - or "ssl_context" in config.keys())): + if (security_type in (SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, + SECURITY_TYPE_SECURE) + and ("encrypted" in config + or "trust" in config + or "trusted_certificates" in config + or "ssl_context" in config)): # TODO: 6.0 - remove "trust" from error message raise ConfigurationError( @@ -1253,7 +1271,7 @@ def open(cls, target, **config): """ from .io import AsyncBoltPool address = cls.parse_target(target) - pool_config, default_workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig) + pool_config, default_workspace_config = Config.consume_chain(config, AsyncPoolConfig, WorkspaceConfig) pool = AsyncBoltPool.open(address, pool_config=pool_config, workspace_config=default_workspace_config) return cls(pool, default_workspace_config) @@ -1278,7 +1296,7 @@ class AsyncNeo4jDriver(_Routing, AsyncDriver): def open(cls, *targets, routing_context=None, **config): from .io import AsyncNeo4jPool addresses = cls.parse_targets(*targets) - pool_config, default_workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig) + pool_config, default_workspace_config = Config.consume_chain(config, AsyncPoolConfig, WorkspaceConfig) pool = AsyncNeo4jPool.open(*addresses, routing_context=routing_context, pool_config=pool_config, workspace_config=default_workspace_config) return cls(pool, default_workspace_config) diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index 2efbc3a8c..b36617242 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -28,13 +28,13 @@ from ..._async_compat.util import AsyncUtil from ..._codec.hydration import v1 as hydration_v1 from ..._codec.packstream import v1 as packstream_v1 -from ..._conf import PoolConfig from ..._deadline import Deadline from ..._exceptions import ( BoltError, BoltHandshakeError, ) from ..._meta import USER_AGENT +from ..._sync.config import PoolConfig from ...addressing import ResolvedAddress from ...api import ( Auth, @@ -49,6 +49,7 @@ ServiceUnavailable, SessionExpired, ) +from ..config import AsyncPoolConfig from ._common import ( AsyncInbox, AsyncOutbox, @@ -392,7 +393,7 @@ async def open( """ if pool_config is None: - pool_config = PoolConfig() + pool_config = AsyncPoolConfig() if deadline is None: deadline = Deadline(None) @@ -402,7 +403,7 @@ async def open( tcp_timeout=pool_config.connection_timeout, deadline=deadline, custom_resolver=pool_config.resolver, - ssl_context=pool_config.get_ssl_context(), + ssl_context=await pool_config.get_ssl_context(), keep_alive=pool_config.keep_alive, ) diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index 1e51191b8..68b206d76 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -36,10 +36,7 @@ ) from ..._async_compat.network import AsyncNetworkUtil from ..._async_compat.util import AsyncUtil -from ..._conf import ( - PoolConfig, - WorkspaceConfig, -) +from ..._conf import WorkspaceConfig from ..._deadline import ( connection_deadline, Deadline, @@ -64,6 +61,7 @@ SessionExpired, WriteServiceUnavailable, ) +from ..config import AsyncPoolConfig from ._bolt import AsyncBolt @@ -83,7 +81,7 @@ class AsyncIOPool(abc.ABC): def __init__(self, opener, pool_config, workspace_config): assert callable(opener) - assert isinstance(pool_config, PoolConfig) + assert isinstance(pool_config, AsyncPoolConfig) assert isinstance(workspace_config, WorkspaceConfig) self.opener = opener diff --git a/src/neo4j/_async/work/result.py b/src/neo4j/_async/work/result.py index 4f172a0c3..00b01d81a 100644 --- a/src/neo4j/_async/work/result.py +++ b/src/neo4j/_async/work/result.py @@ -590,10 +590,10 @@ async def value( was obtained has been closed or the Result has been explicitly consumed. + .. seealso:: :meth:`.Record.value` + .. versionchanged:: 5.0 Can raise :exc:`.ResultConsumedError`. - - .. seealso:: :meth:`.Record.value` """ return [record.value(key, default) async for record in self] @@ -611,10 +611,10 @@ async def values( was obtained has been closed or the Result has been explicitly consumed. + .. seealso:: :meth:`.Record.values` + .. versionchanged:: 5.0 Can raise :exc:`.ResultConsumedError`. - - .. seealso:: :meth:`.Record.values` """ return [record.values(*keys) async for record in self] @@ -640,10 +640,10 @@ async def data(self, *keys: _TResultKey) -> t.List[t.Dict[str, t.Any]]: was obtained has been closed or the Result has been explicitly consumed. + .. seealso:: :meth:`.Record.data` + .. versionchanged:: 5.0 Can raise :exc:`.ResultConsumedError`. - - .. seealso:: :meth:`.Record.data` """ return [record.data(*keys) async for record in self] diff --git a/src/neo4j/_async_compat/concurrency.py b/src/neo4j/_async_compat/concurrency.py index ddf6e7b3e..942c4319e 100644 --- a/src/neo4j/_async_compat/concurrency.py +++ b/src/neo4j/_async_compat/concurrency.py @@ -171,10 +171,10 @@ async def __aexit__(self, t, v, tb): class AsyncCooperativeLock: """Lock placeholder for asyncio Python when working fully cooperatively. - This lock doesn't do anything in async Python. It's threaded counterpart, + This lock doesn't do anything in async Python. Its threaded counterpart, however, is an ordinary :class:`threading.Lock`. The AsyncCooperativeLock only works if there is no await being used - while the lock is acquired. + while the lock is held. """ def __init__(self): diff --git a/src/neo4j/_auth_management.py b/src/neo4j/_auth_management.py index 65a623697..8552c906c 100644 --- a/src/neo4j/_auth_management.py +++ b/src/neo4j/_auth_management.py @@ -21,11 +21,18 @@ import typing as t from dataclasses import dataclass +from ._meta import preview from .exceptions import Neo4jError if t.TYPE_CHECKING: + from os import PathLike + + from typing_extensions import Protocol as _Protocol + from .api import _TAuth +else: + _Protocol = object @dataclass @@ -164,7 +171,7 @@ def handle_security_exception( ... -class AsyncAuthManager(metaclass=abc.ABCMeta): +class AsyncAuthManager(_Protocol, metaclass=abc.ABCMeta): """Async version of :class:`.AuthManager`. .. seealso:: :class:`.AuthManager` @@ -193,3 +200,93 @@ async def handle_security_exception( .. seealso:: :meth:`.AuthManager.handle_security_exception` """ ... + + +@preview("Mutual TLS is a preview feature.") +@dataclass +class ClientCertificate: + """ + Simple data class to hold client certificate information. + + The attributes are the same as the arguments to + :meth:`ssl.SSLContext.load_cert_chain()`. + + .. versionadded:: 5.19 + """ + certfile: t.Union[str, bytes, PathLike[str], PathLike[bytes]] + keyfile: t.Union[str, bytes, PathLike[str], PathLike[bytes], None] = None + password: t.Union[ + t.Callable[[], t.Union[str | bytes]], + str, + bytes, + None + ] = None + + +class ClientCertificateProvider(_Protocol, metaclass=abc.ABCMeta): + """ + Provides a client certificate to the driver for mutual TLS. + + The package provides some default implementations of this class in + :class:`.AsyncClientCertificateProviders` for convenience. + + The driver will call :meth:`.get_certificate` to check if the client wants + the driver to use as new certificate for mutual TLS. + + The certificate is only used as a second factor for authenticating the + client. + The DBMS user still needs to authenticate with an authentication token. + + Note that the work done in the methods of this interface count towards the + connection acquisition affected by the respective timeout setting + :ref:`connection-acquisition-timeout-ref`. + Should fetching the certificate be particularly slow, it might be necessary + to increase the timeout. + + .. warning:: + + The provider **must not** interact with the driver in any way as this + can cause deadlocks and undefined behaviour. + + .. versionadded:: 5.19 + """ + + @abc.abstractmethod + def get_certificate(self) -> t.Optional[ClientCertificate]: + """ + Return the new certificate (if present) to use for new connections. + + If no new certificate is available, return :data:`None`. + This will make the driver continue using the current certificate. + + Note that a new certificate will only be used for new connections. + Already established connections will continue using the old + certificate as TLS is established during connection setup. + + :returns: The new certificate to use for new connections. + """ + ... + + +class AsyncClientCertificateProvider(_Protocol, metaclass=abc.ABCMeta): + """ + Async version of :class:`.ClientCertificateProvider`. + + **This is a preview** (see :ref:`filter-warnings-ref`). + It might be changed without following the deprecation policy. + See also + https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + + .. seealso:: :class:`.ClientCertificateProvider` + + .. versionadded:: 5.19 + """ + + @abc.abstractmethod + async def get_certificate(self) -> t.Optional[ClientCertificate]: + """ + Return the new certificate (if present) to use for new connections. + + .. seealso:: :meth:`.ClientCertificateProvider.get_certificate` + """ + ... diff --git a/src/neo4j/_conf.py b/src/neo4j/_conf.py index 91fac0af3..330a4fc8b 100644 --- a/src/neo4j/_conf.py +++ b/src/neo4j/_conf.py @@ -16,6 +16,7 @@ from __future__ import annotations +import typing as t from abc import ABCMeta from collections.abc import Mapping @@ -347,123 +348,6 @@ def _trust_to_trusted_certificates(pool_config, trust): pool_config.trusted_certificates = TrustAll() -class PoolConfig(Config): - """ Connection pool configuration. - """ - - #: Max Connection Lifetime - max_connection_lifetime = 3600 # seconds - # The maximum duration the driver will keep a connection for before being removed from the pool. - - #: Timeout after which idle connections will be checked for liveness - #: before returned from the pool. - liveness_check_timeout = None - - #: Max Connection Pool Size - max_connection_pool_size = 100 - # The maximum total number of connections allowed, per host (i.e. cluster nodes), to be managed by the connection pool. - - #: Connection Timeout - connection_timeout = 30.0 # seconds - # The maximum amount of time to wait for a TCP connection to be established. - - #: Trust - trust = DeprecatedAlternative( - "trusted_certificates", _trust_to_trusted_certificates - ) - # Specify how to determine the authenticity of encryption certificates provided by the Neo4j instance on connection. - - #: Custom Resolver - resolver = None - # Custom resolver function, returning list of resolved addresses. - - #: Encrypted - encrypted = False - # Specify whether to use an encrypted connection between the driver and server. - - #: SSL Certificates to Trust - trusted_certificates = TrustSystemCAs() - # Specify how to determine the authenticity of encryption certificates - # provided by the Neo4j instance on connection. - # * `neo4j.TrustSystemCAs()`: Use system trust store. (default) - # * `neo4j.TrustAll()`: Trust any certificate. - # * `neo4j.TrustCustomCAs("", ...)`: - # Trust the specified certificate(s). - - #: Custom SSL context to use for wrapping sockets - ssl_context = None - # Use any custom SSL context to wrap sockets. - # Overwrites `trusted_certificates` and `encrypted`. - # The use of this option is strongly discouraged. - - #: User Agent (Python Driver Specific) - user_agent = None - # Specify the client agent name. - - #: Socket Keep Alive (Python and .NET Driver Specific) - keep_alive = True - # Specify whether TCP keep-alive should be enabled. - - #: Authentication provider - auth = None - - #: Lowest notification severity for the server to return - notifications_min_severity = None - - #: List of notification categories for the server to ignore - notifications_disabled_categories = None - - #: Opt-Out of telemetry collection - telemetry_disabled = False - - def get_ssl_context(self): - if self.ssl_context is not None: - return self.ssl_context - - if not self.encrypted: - return None - - import ssl - - # SSL stands for Secure Sockets Layer and was originally created by Netscape. - # SSLv2 and SSLv3 are the 2 versions of this protocol (SSLv1 was never publicly released). - # After SSLv3, SSL was renamed to TLS. - # TLS stands for Transport Layer Security and started with TLSv1.0 which is an upgraded version of SSLv3. - # SSLv2 - (Disabled) - # SSLv3 - (Disabled) - # TLS 1.0 - Released in 1999, published as RFC 2246. (Disabled) - # TLS 1.1 - Released in 2006, published as RFC 4346. (Disabled) - # TLS 1.2 - Released in 2008, published as RFC 5246. - # https://docs.python.org/3.7/library/ssl.html#ssl.PROTOCOL_TLS_CLIENT - ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - - # For recommended security options see - # https://docs.python.org/3.10/library/ssl.html#protocol-versions - ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2 - - if isinstance(self.trusted_certificates, TrustAll): - # trust any certificate - ssl_context.check_hostname = False - # https://docs.python.org/3.7/library/ssl.html#ssl.CERT_NONE - ssl_context.verify_mode = ssl.CERT_NONE - elif isinstance(self.trusted_certificates, TrustCustomCAs): - # trust the specified certificate(s) - ssl_context.check_hostname = True - ssl_context.verify_mode = ssl.CERT_REQUIRED - for cert in self.trusted_certificates.certs: - ssl_context.load_verify_locations(cert) - else: - # default - # trust system CA certificates - ssl_context.check_hostname = True - ssl_context.verify_mode = ssl.CERT_REQUIRED - # Must be load_default_certs, not set_default_verify_paths to - # work on Windows with system CAs. - ssl_context.load_default_certs() - - return ssl_context - - class WorkspaceConfig(Config): """ WorkSpace configuration. """ diff --git a/src/neo4j/_sync/auth_management.py b/src/neo4j/_sync/auth_management.py index f075eac8b..dcae6c351 100644 --- a/src/neo4j/_sync/auth_management.py +++ b/src/neo4j/_sync/auth_management.py @@ -19,12 +19,18 @@ import typing as t from logging import getLogger -from .._async_compat.concurrency import Lock +from .._async_compat.concurrency import ( + CooperativeLock, + Lock, +) from .._auth_management import ( AuthManager, + ClientCertificate, + ClientCertificateProvider, expiring_auth_has_expired, ExpiringAuth, ) +from .._meta import preview if t.TYPE_CHECKING: @@ -285,3 +291,127 @@ def auth_provider(): "Neo.ClientError.Security.Unauthorized", )) return Neo4jAuthTokenManager(provider, handled_codes) + + +class _StaticClientCertificateProvider(ClientCertificateProvider): + _cert: t.Optional[ClientCertificate] + + def __init__(self, cert: ClientCertificate) -> None: + self._cert = cert + + def get_certificate(self) -> t.Optional[ClientCertificate]: + cert, self._cert = self._cert, None + return cert + + +@preview("Mutual TLS is a preview feature.") +class RotatingClientCertificateProvider(ClientCertificateProvider): + """ + Implementation of a certificate provider that can rotate certificates. + + The provider will make the driver use the initial certificate for all + connections until the certificate is updated using the + :meth:`update_certificate` method. + From that point on, the new certificate will be used for all new + connections until :meth:`update_certificate` is called again and so on. + + **This is a preview** (see :ref:`filter-warnings-ref`). + It might be changed without following the deprecation policy. + See also + https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + + Example:: + + from neo4j import GraphDatabase + from neo4j.auth_management import ( + ClientCertificate, + ClientCertificateProviders, + ) + + + provider = ClientCertificateProviders.rotating( + ClientCertificate( + certfile="path/to/certfile.pem", + keyfile="path/to/keyfile.pem", + password=lambda: "super_secret_password" + ) + ) + driver = GraphDatabase.driver( + # secure driver must be configured for client certificate + # to be used: (...+s[sc] scheme or encrypted=True) + "neo4j+s://example.com:7687", + # auth still required as before, unless server is configured to not + # use authentication + auth=("neo4j", "password"), + client_certificate=provider + ) + + # do work with the driver, until the certificate needs to be rotated + ... + + provider.update_certificate( + ClientCertificate( + certfile="path/to/new/certfile.pem", + keyfile="path/to/new/keyfile.pem", + password=lambda: "new_super_secret_password" + ) + ) + + # do more work with the driver, until the certificate needs to be + # rotated again + ... + + :param initial_cert: The certificate to use initially. + + .. versionadded:: 5.19 + + """ + def __init__(self, initial_cert: ClientCertificate) -> None: + self._cert: t.Optional[ClientCertificate] = initial_cert + self._lock = CooperativeLock() + + def get_certificate(self) -> t.Optional[ClientCertificate]: + with self._lock: + cert, self._cert = self._cert, None + return cert + + def update_certificate(self, cert: ClientCertificate) -> None: + """ + Update the certificate to use for new connections. + """ + with self._lock: + self._cert = cert + + +class ClientCertificateProviders: + """A collection of :class:`.ClientCertificateProvider` factories. + + **This is a preview** (see :ref:`filter-warnings-ref`). + It might be changed without following the deprecation policy. + See also + https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + + .. versionadded:: 5.19 + """ + @staticmethod + @preview("Mutual TLS is a preview feature.") + def static(cert: ClientCertificate) -> ClientCertificateProvider: + """ + Create a static client certificate provider. + + The provider simply makes the driver use the given certificate for all + connections. + """ + return _StaticClientCertificateProvider(cert) + + @staticmethod + @preview("Mutual TLS is a preview feature.") + def rotating( + initial_cert: ClientCertificate + ) -> RotatingClientCertificateProvider: + """ + Create certificate provider that allows for rotating certificates. + + .. seealso:: :class:`.RotatingClientCertificateProvider` + """ + return RotatingClientCertificateProvider(initial_cert) diff --git a/src/neo4j/_sync/config.py b/src/neo4j/_sync/config.py new file mode 100644 index 000000000..20e0daa06 --- /dev/null +++ b/src/neo4j/_sync/config.py @@ -0,0 +1,187 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +import typing as t + +from .._async_compat.concurrency import Lock +from .._conf import ( + _trust_to_trusted_certificates, + Config, + DeprecatedAlternative, + TrustAll, + TrustCustomCAs, + TrustSystemCAs, +) + + +if t.TYPE_CHECKING: + import ssl + + from .._auth_management import ClientCertificate + + +class PoolConfig(Config): + """ Connection pool configuration. + """ + + #: Max Connection Lifetime + max_connection_lifetime = 3600 # seconds + # The maximum duration the driver will keep a connection for before being removed from the pool. + + #: Timeout after which idle connections will be checked for liveness + #: before returned from the pool. + liveness_check_timeout = None + + #: Max Connection Pool Size + max_connection_pool_size = 100 + # The maximum total number of connections allowed, per host (i.e. cluster nodes), to be managed by the connection pool. + + #: Connection Timeout + connection_timeout = 30.0 # seconds + # The maximum amount of time to wait for a TCP connection to be established. + + #: Trust + trust = DeprecatedAlternative( + "trusted_certificates", _trust_to_trusted_certificates + ) + # Specify how to determine the authenticity of encryption certificates provided by the Neo4j instance on connection. + + #: Custom Resolver + resolver = None + # Custom resolver function, returning list of resolved addresses. + + #: Encrypted + encrypted = False + # Specify whether to use an encrypted connection between the driver and server. + + #: SSL Certificates to Trust + trusted_certificates = TrustSystemCAs() + # Specify how to determine the authenticity of encryption certificates + # provided by the Neo4j instance on connection. + # * `neo4j.TrustSystemCAs()`: Use system trust store. (default) + # * `neo4j.TrustAll()`: Trust any certificate. + # * `neo4j.TrustCustomCAs("", ...)`: + # Trust the specified certificate(s). + + #: Certificate to use for mTLS as 2nd authentication factor. + client_certificate = None + + #: Custom SSL context to use for wrapping sockets + ssl_context = None + # Use any custom SSL context to wrap sockets. + # Overwrites `trusted_certificates` and `encrypted`. + # The use of this option is strongly discouraged. + + #: User Agent (Python Driver Specific) + user_agent = None + # Specify the client agent name. + + #: Socket Keep Alive (Python and .NET Driver Specific) + keep_alive = True + # Specify whether TCP keep-alive should be enabled. + + #: Authentication provider + auth = None + + #: Lowest notification severity for the server to return + notifications_min_severity = None + + #: List of notification categories for the server to ignore + notifications_disabled_categories = None + + #: Opt-Out of telemetry collection + telemetry_disabled = False + + _ssl_context_cache: t.Optional[ssl.SSLContext] + _ssl_context_cache_lock: Lock + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._ssl_context_cache = None + self._ssl_context_cache_lock = Lock() + + def get_ssl_context(self) -> t.Optional[ssl.SSLContext]: + if self.ssl_context is not None: + return self.ssl_context + + if not self.encrypted: + return None + + client_cert: t.Optional[ClientCertificate] = None + + # try to serve the cached ssl context + with self._ssl_context_cache_lock: + if self._ssl_context_cache is not None: + if self.client_certificate is None: + return self._ssl_context_cache + client_cert = self.client_certificate.get_certificate() + if client_cert is None: + return self._ssl_context_cache + elif self.client_certificate is not None: + client_cert = self.client_certificate.get_certificate() + + import ssl + + # SSL stands for Secure Sockets Layer and was originally created by + # Netscape. + # SSLv2 and SSLv3 are the 2 versions of this protocol (SSLv1 was + # never publicly released). + # After SSLv3, SSL was renamed to TLS. + # TLS stands for Transport Layer Security and started with TLSv1.0 + # which is an upgraded version of SSLv3. + # SSLv2 - (Disabled) + # SSLv3 - (Disabled) + # TLS 1.0 - Released in 1999, published as RFC 2246. (Disabled) + # TLS 1.1 - Released in 2006, published as RFC 4346. (Disabled) + # TLS 1.2 - Released in 2008, published as RFC 5246. + # https://docs.python.org/3.7/library/ssl.html#ssl.PROTOCOL_TLS_CLIENT + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + + # For recommended security options see + # https://docs.python.org/3.10/library/ssl.html#protocol-versions + ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2 + + if isinstance(self.trusted_certificates, TrustAll): + # trust any certificate + ssl_context.check_hostname = False + # https://docs.python.org/3.7/library/ssl.html#ssl.CERT_NONE + ssl_context.verify_mode = ssl.CERT_NONE + elif isinstance(self.trusted_certificates, TrustCustomCAs): + # trust the specified certificate(s) + ssl_context.check_hostname = True + ssl_context.verify_mode = ssl.CERT_REQUIRED + for cert in self.trusted_certificates.certs: + ssl_context.load_verify_locations(cert) + else: + # default + # trust system CA certificates + ssl_context.check_hostname = True + ssl_context.verify_mode = ssl.CERT_REQUIRED + # Must be load_default_certs, not set_default_verify_paths to + # work on Windows with system CAs. + ssl_context.load_default_certs() + + if client_cert is not None: + ssl_context.load_cert_chain( + client_cert.certfile, + keyfile=client_cert.keyfile, + password=client_cert.password, + ) + + self._ssl_context_cache = ssl_context + return ssl_context diff --git a/src/neo4j/_sync/driver.py b/src/neo4j/_sync/driver.py index 3db702016..2b8a916c5 100644 --- a/src/neo4j/_sync/driver.py +++ b/src/neo4j/_sync/driver.py @@ -37,7 +37,6 @@ from .._conf import ( Config, ConfigurationError, - PoolConfig, SessionConfig, TrustAll, TrustStore, @@ -46,6 +45,7 @@ from .._meta import ( deprecation_warn, experimental_warn, + preview_warn, unclosed_resource_warn, ) from .._work import ( @@ -79,13 +79,17 @@ from ..auth_management import ( AuthManager, AuthManagers, + ClientCertificate, + ClientCertificateProvider, ) from ..exceptions import Neo4jError +from .auth_management import _StaticClientCertificateProvider from .bookmark_manager import ( Neo4jBookmarkManager, TBmConsumer as _TBmConsumer, TBmSupplier as _TBmSupplier, ) +from .config import PoolConfig from .work import ( ManagedTransaction, Result, @@ -140,7 +144,10 @@ def driver( ] = ..., encrypted: bool = ..., trusted_certificates: TrustStore = ..., - ssl_context: ssl.SSLContext = ..., + client_certificate: t.Union[ + ClientCertificate, ClientCertificateProvider, None + ] = ..., + ssl_context: t.Optional[ssl.SSLContext] = ..., user_agent: str = ..., keep_alive: bool = ..., notifications_min_severity: t.Optional[ @@ -193,6 +200,16 @@ def driver( auth = AuthManagers.static(auth) config["auth"] = auth + client_certificate = config.get("client_certificate") + if isinstance(client_certificate, ClientCertificate): + # using internal class until public factory is GA: + # AsyncClientCertificateProviders.static + config["client_certificate"] = \ + _StaticClientCertificateProvider(client_certificate) + if client_certificate is not None: + preview_warn("Mutual TLS is a preview feature.", + stack_level=2) + # TODO: 6.0 - remove "trust" config option if "trust" in config.keys(): if config["trust"] not in ( @@ -220,11 +237,12 @@ def driver( ) ) - if (security_type in [SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, SECURITY_TYPE_SECURE] - and ("encrypted" in config.keys() - or "trust" in config.keys() - or "trusted_certificates" in config.keys() - or "ssl_context" in config.keys())): + if (security_type in (SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, + SECURITY_TYPE_SECURE) + and ("encrypted" in config + or "trust" in config + or "trusted_certificates" in config + or "ssl_context" in config)): # TODO: 6.0 - remove "trust" from error message raise ConfigurationError( diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index 69caaa6c2..2b4b541aa 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -28,13 +28,13 @@ from ..._async_compat.util import Util from ..._codec.hydration import v1 as hydration_v1 from ..._codec.packstream import v1 as packstream_v1 -from ..._conf import PoolConfig from ..._deadline import Deadline from ..._exceptions import ( BoltError, BoltHandshakeError, ) from ..._meta import USER_AGENT +from ..._sync.config import PoolConfig from ...addressing import ResolvedAddress from ...api import ( Auth, @@ -49,6 +49,7 @@ ServiceUnavailable, SessionExpired, ) +from ..config import PoolConfig from ._common import ( CommitResponse, Inbox, diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py index baabfd41b..0a8464d22 100644 --- a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -36,10 +36,7 @@ ) from ..._async_compat.network import NetworkUtil from ..._async_compat.util import Util -from ..._conf import ( - PoolConfig, - WorkspaceConfig, -) +from ..._conf import WorkspaceConfig from ..._deadline import ( connection_deadline, Deadline, @@ -61,6 +58,7 @@ SessionExpired, WriteServiceUnavailable, ) +from ..config import PoolConfig from ._bolt import Bolt diff --git a/src/neo4j/_sync/work/result.py b/src/neo4j/_sync/work/result.py index 4714b6915..7de984146 100644 --- a/src/neo4j/_sync/work/result.py +++ b/src/neo4j/_sync/work/result.py @@ -590,10 +590,10 @@ def value( was obtained has been closed or the Result has been explicitly consumed. + .. seealso:: :meth:`.Record.value` + .. versionchanged:: 5.0 Can raise :exc:`.ResultConsumedError`. - - .. seealso:: :meth:`.Record.value` """ return [record.value(key, default) for record in self] @@ -611,10 +611,10 @@ def values( was obtained has been closed or the Result has been explicitly consumed. + .. seealso:: :meth:`.Record.values` + .. versionchanged:: 5.0 Can raise :exc:`.ResultConsumedError`. - - .. seealso:: :meth:`.Record.values` """ return [record.values(*keys) for record in self] @@ -640,10 +640,10 @@ def data(self, *keys: _TResultKey) -> t.List[t.Dict[str, t.Any]]: was obtained has been closed or the Result has been explicitly consumed. + .. seealso:: :meth:`.Record.data` + .. versionchanged:: 5.0 Can raise :exc:`.ResultConsumedError`. - - .. seealso:: :meth:`.Record.data` """ return [record.data(*keys) for record in self] diff --git a/src/neo4j/_work/summary.py b/src/neo4j/_work/summary.py index 6d0b18fe9..e02dec1a8 100644 --- a/src/neo4j/_work/summary.py +++ b/src/neo4j/_work/summary.py @@ -112,9 +112,9 @@ def __init__(self, address: Address, **metadata: t.Any) -> None: def summary_notifications(self) -> t.List[SummaryNotification]: """The same as ``notifications`` but in a parsed, structured form. - .. versionadded:: 5.7 - .. seealso:: :attr:`.notifications`, :class:`.SummaryNotification` + + .. versionadded:: 5.7 """ if getattr(self, "_summary_notifications", None) is None: self._summary_notifications = [ @@ -236,9 +236,9 @@ def contains_system_updates(self) -> bool: class SummaryNotification: """Structured form of a notification received from the server. - .. versionadded:: 5.7 - .. seealso:: :attr:`.ResultSummary.summary_notifications` + + .. versionadded:: 5.7 """ title: str = "" @@ -280,9 +280,9 @@ def _from_metadata(cls, metadata): class SummaryNotificationPosition: """Structured form of a notification position received from the server. - .. versionadded:: 5.7 - .. seealso:: :class:`.SummaryNotification` + + .. versionadded:: 5.7 """ #: The line number of the notification. Line numbers start at 1. diff --git a/src/neo4j/addressing.py b/src/neo4j/addressing.py index ae62b9b99..f18ad1424 100644 --- a/src/neo4j/addressing.py +++ b/src/neo4j/addressing.py @@ -148,7 +148,7 @@ def parse( :param default_port: The default port to use if none is specified. :data:`None` indicates to use ``0`` as default. - :return: The parsed address. + :returns: The parsed address. """ if not isinstance(s, str): raise TypeError("Address.parse requires a string argument") @@ -198,7 +198,7 @@ def parse_list( :param default_port: The default port to use if none is specified. :data:`None` indicates to use ``0`` as default. - :return: The list of parsed addresses. + :returns: The list of parsed addresses. """ if not all(isinstance(s0, str) for s0 in s): raise TypeError("Address.parse_list requires a string argument") diff --git a/src/neo4j/auth_management.py b/src/neo4j/auth_management.py index 43cce8d78..d0f534ed9 100644 --- a/src/neo4j/auth_management.py +++ b/src/neo4j/auth_management.py @@ -14,19 +14,37 @@ # limitations under the License. -from ._async.auth_management import AsyncAuthManagers +from ._async.auth_management import ( + AsyncAuthManagers, + AsyncClientCertificateProviders, + AsyncRotatingClientCertificateProvider, +) from ._auth_management import ( AsyncAuthManager, + AsyncClientCertificateProvider, AuthManager, + ClientCertificate, + ClientCertificateProvider, ExpiringAuth, ) -from ._sync.auth_management import AuthManagers +from ._sync.auth_management import ( + AuthManagers, + ClientCertificateProviders, + RotatingClientCertificateProvider, +) __all__ = [ "AsyncAuthManager", "AsyncAuthManagers", + "AsyncClientCertificateProvider", + "AsyncClientCertificateProviders", + "AsyncRotatingClientCertificateProvider", "AuthManager", "AuthManagers", + "ClientCertificate", + "ClientCertificateProvider", + "ClientCertificateProviders", "ExpiringAuth", + "RotatingClientCertificateProvider", ] diff --git a/src/neo4j/conf.py b/src/neo4j/conf.py index 74c47aa89..d7336b66a 100644 --- a/src/neo4j/conf.py +++ b/src/neo4j/conf.py @@ -22,13 +22,13 @@ DeprecatedAlias, DeprecatedAlternative, iter_items, - PoolConfig, RoutingConfig, SessionConfig, TransactionConfig, WorkspaceConfig, ) from ._meta import deprecation_warn as _deprecation_warn +from ._sync.config import PoolConfig __all__ = [ diff --git a/testkitbackend/_async/backend.py b/testkitbackend/_async/backend.py index 0d69ab7c7..a9e0b0681 100644 --- a/testkitbackend/_async/backend.py +++ b/testkitbackend/_async/backend.py @@ -59,6 +59,8 @@ def __init__(self, rd, wr): self.auth_token_on_expiration_supplies = {} self.basic_auth_token_supplies = {} self.expiring_auth_token_supplies = {} + self.client_cert_providers = {} + self.client_cert_supplies = {} self.bookmark_managers = {} self.bookmarks_consumptions = {} self.bookmarks_supplies = {} diff --git a/testkitbackend/_async/requests.py b/testkitbackend/_async/requests.py index 405df1d0a..b0ebcdd78 100644 --- a/testkitbackend/_async/requests.py +++ b/testkitbackend/_async/requests.py @@ -18,6 +18,7 @@ import json import re import ssl +import typing as t import warnings from os import path @@ -27,9 +28,11 @@ import neo4j.api import neo4j.auth_management from neo4j._async_compat.util import AsyncUtil +from neo4j._auth_management import ClientCertificate from neo4j.auth_management import ( AsyncAuthManager, AsyncAuthManagers, + AsyncClientCertificateProvider, ExpiringAuth, ) @@ -38,6 +41,7 @@ test_subtest_skips, totestkit, ) +from .._warning_check import warnings_check from ..exceptions import MarkdAsDriverException @@ -100,12 +104,29 @@ async def GetFeatures(backend, data): async def NewDriver(backend, data): + expected_warnings = [] + auth = fromtestkit.to_auth_token(data, "authorizationToken") if auth is None and data.get("authTokenManagerId") is not None: auth = backend.auth_token_managers[data["authTokenManagerId"]] else: data.mark_item_as_read_if_equals("authTokenManagerId", None) kwargs = {} + client_cert_provider_id = data.get("clientCertificateProviderId") + if client_cert_provider_id is not None: + kwargs["client_certificate"] = \ + backend.client_cert_providers[client_cert_provider_id] + data.mark_item_as_read_if_equals("clientCertificate", None) + expected_warnings.append( + (neo4j.PreviewWarning, "Mutual TLS is a preview feature.") + ) + else: + client_cert = fromtestkit.to_client_cert(data, "clientCertificate") + if client_cert is not None: + kwargs["client_certificate"] = client_cert + expected_warnings.append( + (neo4j.PreviewWarning, "Mutual TLS is a preview feature.") + ) if data["resolverRegistered"] or data["domainNameResolverRegistered"]: kwargs["resolver"] = resolution_func( backend, data["resolverRegistered"], @@ -145,9 +166,10 @@ async def NewDriver(backend, data): kwargs["trusted_certificates"] = neo4j.TrustCustomCAs(*cert_paths) fromtestkit.set_notifications_config(kwargs, data) - driver = neo4j.AsyncGraphDatabase.driver( - data["uri"], auth=auth, user_agent=data["userAgent"], **kwargs, - ) + with warnings_check(expected_warnings): + driver = neo4j.AsyncGraphDatabase.driver( + data["uri"], auth=auth, user_agent=data["userAgent"], **kwargs, + ) key = backend.next_key() backend.drivers[key] = driver await backend.send_response("Driver", {"id": key}) @@ -300,6 +322,59 @@ async def BearerAuthTokenProviderCompleted(backend, data): backend.expiring_auth_token_supplies[data["requestId"]] = expiring_auth +class TestKitClientCertificateProvider(AsyncClientCertificateProvider): + def __init__(self, backend): + self.id = backend.next_key() + self._backend = backend + + async def get_certificate(self) -> t.Optional[ClientCertificate]: + request_id = self._backend.next_key() + await self._backend.send_response( + "ClientCertificateProviderRequest", + { + "id": request_id, + "clientCertificateProviderId": self.id, + } + ) + if not await self._backend.process_request(): + # connection was closed before end of next message + return None + if request_id not in self._backend.client_cert_supplies: + raise RuntimeError( + "Backend did not receive expected " + "ClientCertificateProviderCompleted message for id " + f"{request_id}" + ) + return self._backend.client_cert_supplies.pop(request_id) + + +async def NewClientCertificateProvider(backend, data): + provider = TestKitClientCertificateProvider(backend) + backend.client_cert_providers[provider.id] = provider + await backend.send_response( + "ClientCertificateProvider", {"id": provider.id} + ) + + +async def ClientCertificateProviderClose(backend, data): + client_cert_provider_id = data["id"] + del backend.client_cert_providers[client_cert_provider_id] + await backend.send_response( + "ClientCertificateProvider", {"id": client_cert_provider_id} + ) + + +async def ClientCertificateProviderCompleted(backend, data): + has_update = data["hasUpdate"] + request_id = data["requestId"] + if not has_update: + data.mark_item_as_read("clientCertificate", recursive=True) + backend.client_cert_supplies[request_id] = None + return + client_cert = fromtestkit.to_client_cert(data, "clientCertificate") + backend.client_cert_supplies[request_id] = client_cert + + async def VerifyConnectivity(backend, data): driver_id = data["driverId"] driver = backend.drivers[driver_id] diff --git a/testkitbackend/_sync/backend.py b/testkitbackend/_sync/backend.py index 1e5bef247..fe3d24ce0 100644 --- a/testkitbackend/_sync/backend.py +++ b/testkitbackend/_sync/backend.py @@ -59,6 +59,8 @@ def __init__(self, rd, wr): self.auth_token_on_expiration_supplies = {} self.basic_auth_token_supplies = {} self.expiring_auth_token_supplies = {} + self.client_cert_providers = {} + self.client_cert_supplies = {} self.bookmark_managers = {} self.bookmarks_consumptions = {} self.bookmarks_supplies = {} diff --git a/testkitbackend/_sync/requests.py b/testkitbackend/_sync/requests.py index 3e64f77d8..c822a50a0 100644 --- a/testkitbackend/_sync/requests.py +++ b/testkitbackend/_sync/requests.py @@ -18,6 +18,7 @@ import json import re import ssl +import typing as t import warnings from os import path @@ -27,9 +28,11 @@ import neo4j.api import neo4j.auth_management from neo4j._async_compat.util import Util +from neo4j._auth_management import ClientCertificate from neo4j.auth_management import ( AuthManager, AuthManagers, + ClientCertificateProvider, ExpiringAuth, ) @@ -38,6 +41,7 @@ test_subtest_skips, totestkit, ) +from .._warning_check import warnings_check from ..exceptions import MarkdAsDriverException @@ -100,12 +104,29 @@ def GetFeatures(backend, data): def NewDriver(backend, data): + expected_warnings = [] + auth = fromtestkit.to_auth_token(data, "authorizationToken") if auth is None and data.get("authTokenManagerId") is not None: auth = backend.auth_token_managers[data["authTokenManagerId"]] else: data.mark_item_as_read_if_equals("authTokenManagerId", None) kwargs = {} + client_cert_provider_id = data.get("clientCertificateProviderId") + if client_cert_provider_id is not None: + kwargs["client_certificate"] = \ + backend.client_cert_providers[client_cert_provider_id] + data.mark_item_as_read_if_equals("clientCertificate", None) + expected_warnings.append( + (neo4j.PreviewWarning, "Mutual TLS is a preview feature.") + ) + else: + client_cert = fromtestkit.to_client_cert(data, "clientCertificate") + if client_cert is not None: + kwargs["client_certificate"] = client_cert + expected_warnings.append( + (neo4j.PreviewWarning, "Mutual TLS is a preview feature.") + ) if data["resolverRegistered"] or data["domainNameResolverRegistered"]: kwargs["resolver"] = resolution_func( backend, data["resolverRegistered"], @@ -145,9 +166,10 @@ def NewDriver(backend, data): kwargs["trusted_certificates"] = neo4j.TrustCustomCAs(*cert_paths) fromtestkit.set_notifications_config(kwargs, data) - driver = neo4j.GraphDatabase.driver( - data["uri"], auth=auth, user_agent=data["userAgent"], **kwargs, - ) + with warnings_check(expected_warnings): + driver = neo4j.GraphDatabase.driver( + data["uri"], auth=auth, user_agent=data["userAgent"], **kwargs, + ) key = backend.next_key() backend.drivers[key] = driver backend.send_response("Driver", {"id": key}) @@ -300,6 +322,59 @@ def BearerAuthTokenProviderCompleted(backend, data): backend.expiring_auth_token_supplies[data["requestId"]] = expiring_auth +class TestKitClientCertificateProvider(ClientCertificateProvider): + def __init__(self, backend): + self.id = backend.next_key() + self._backend = backend + + def get_certificate(self) -> t.Optional[ClientCertificate]: + request_id = self._backend.next_key() + self._backend.send_response( + "ClientCertificateProviderRequest", + { + "id": request_id, + "clientCertificateProviderId": self.id, + } + ) + if not self._backend.process_request(): + # connection was closed before end of next message + return None + if request_id not in self._backend.client_cert_supplies: + raise RuntimeError( + "Backend did not receive expected " + "ClientCertificateProviderCompleted message for id " + f"{request_id}" + ) + return self._backend.client_cert_supplies.pop(request_id) + + +def NewClientCertificateProvider(backend, data): + provider = TestKitClientCertificateProvider(backend) + backend.client_cert_providers[provider.id] = provider + backend.send_response( + "ClientCertificateProvider", {"id": provider.id} + ) + + +def ClientCertificateProviderClose(backend, data): + client_cert_provider_id = data["id"] + del backend.client_cert_providers[client_cert_provider_id] + backend.send_response( + "ClientCertificateProvider", {"id": client_cert_provider_id} + ) + + +def ClientCertificateProviderCompleted(backend, data): + has_update = data["hasUpdate"] + request_id = data["requestId"] + if not has_update: + data.mark_item_as_read("clientCertificate", recursive=True) + backend.client_cert_supplies[request_id] = None + return + client_cert = fromtestkit.to_client_cert(data, "clientCertificate") + backend.client_cert_supplies[request_id] = client_cert + + def VerifyConnectivity(backend, data): driver_id = data["driverId"] driver = backend.drivers[driver_id] diff --git a/testkitbackend/fromtestkit.py b/testkitbackend/fromtestkit.py index 43c468734..87ff65cc4 100644 --- a/testkitbackend/fromtestkit.py +++ b/testkitbackend/fromtestkit.py @@ -14,6 +14,9 @@ # limitations under the License. +from __future__ import annotations + +import typing as t from datetime import timedelta import pytz @@ -24,6 +27,7 @@ NotificationMinimumSeverity, Query, ) +from neo4j.auth_management import ClientCertificate from neo4j.spatial import ( CartesianPoint, WGS84Point, @@ -35,6 +39,8 @@ Time, ) +from ._warning_check import warnings_check + def to_cypher_and_params(data): from .backend import Request @@ -181,6 +187,19 @@ def to_auth_token(data, key): return auth +def to_client_cert(data, key) -> t.Optional[ClientCertificate]: + if data[key] is None: + return None + data[key].mark_item_as_read_if_equals("name", "ClientCertificate") + cert_data = data[key]["data"] + with warnings_check(( + (neo4j.PreviewWarning, "Mutual TLS is a preview feature."), + )): + return ClientCertificate( + cert_data["certfile"], cert_data["keyfile"], cert_data["password"] + ) + + def set_notifications_config(config, data): if "notificationsMinSeverity" in data: config["notifications_min_severity"] = \ diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index 61915f04e..96197644a 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -36,6 +36,7 @@ "Feature:API:RetryableExceptions": true, "Feature:API:Session:AuthConfig": true, "Feature:API:Session:NotificationsConfig": true, + "Feature:API:SSLClientCertificate": true, "Feature:API:SSLConfig": true, "Feature:API:SSLSchemes": true, "Feature:API:Type.Spatial": true, diff --git a/tests/unit/async_/fixtures/fake_pool.py b/tests/unit/async_/fixtures/fake_pool.py index 9792b2b65..ef0b4d982 100644 --- a/tests/unit/async_/fixtures/fake_pool.py +++ b/tests/unit/async_/fixtures/fake_pool.py @@ -16,8 +16,8 @@ import pytest +from neo4j._async.config import AsyncPoolConfig from neo4j._async.io._pool import AsyncIOPool -from neo4j._conf import PoolConfig __all__ = [ @@ -31,7 +31,7 @@ def fake_pool(async_fake_connection_generator, mocker): assert not hasattr(pool, "acquired_connection_mocks") pool.buffered_connection_mocks = [] pool.acquired_connection_mocks = [] - pool.pool_config = PoolConfig() + pool.pool_config = AsyncPoolConfig() def acquire_side_effect(*_, **__): if pool.buffered_connection_mocks: diff --git a/tests/unit/async_/io/test_class_bolt3.py b/tests/unit/async_/io/test_class_bolt3.py index 8a14cb48c..42d38a9bd 100644 --- a/tests/unit/async_/io/test_class_bolt3.py +++ b/tests/unit/async_/io/test_class_bolt3.py @@ -22,8 +22,8 @@ import neo4j from neo4j._api import TelemetryAPI +from neo4j._async.config import AsyncPoolConfig from neo4j._async.io._bolt3 import AsyncBolt3 -from neo4j._conf import PoolConfig from neo4j._meta import USER_AGENT from neo4j.exceptions import ConfigurationError @@ -62,14 +62,14 @@ def test_conn_is_not_stale(fake_socket, set_stale): def test_db_extra_not_supported_in_begin(fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) - connection = AsyncBolt3(address, fake_socket(address), PoolConfig.max_connection_lifetime) + connection = AsyncBolt3(address, fake_socket(address), AsyncPoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError): connection.begin(db="something") def test_db_extra_not_supported_in_run(fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) - connection = AsyncBolt3(address, fake_socket(address), PoolConfig.max_connection_lifetime) + connection = AsyncBolt3(address, fake_socket(address), AsyncPoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError): connection.run("", db="something") @@ -78,7 +78,7 @@ def test_db_extra_not_supported_in_run(fake_socket): async def test_simple_discard(fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt3.UNPACKER_CLS) - connection = AsyncBolt3(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt3(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.discard() await connection.send_all() tag, fields = await socket.pop_message() @@ -90,7 +90,7 @@ async def test_simple_discard(fake_socket): async def test_simple_pull(fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt3.UNPACKER_CLS) - connection = AsyncBolt3(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt3(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.pull() await connection.send_all() tag, fields = await socket.pop_message() @@ -108,7 +108,7 @@ async def test_telemetry_message( address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt3.UNPACKER_CLS) connection = AsyncBolt3( - address, socket, PoolConfig.max_connection_lifetime, + address, socket, AsyncPoolConfig.max_connection_lifetime, telemetry_disabled=driver_disabled ) if serv_enabled: @@ -135,7 +135,7 @@ async def test_hint_recv_timeout_seconds_gets_ignored( "hints": {"connection.recv_timeout_seconds": recv_timeout}, }) connection = AsyncBolt3( - address, sockets.client, PoolConfig.max_connection_lifetime + address, sockets.client, AsyncPoolConfig.max_connection_lifetime ) await connection.hello() sockets.client.settimeout.assert_not_called() @@ -163,7 +163,7 @@ async def test_credentials_are_not_logged( sockets.client.settimeout = mocker.Mock() await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) connection = AsyncBolt3( - address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + address, sockets.client, AsyncPoolConfig.max_connection_lifetime, auth=auth ) with caplog.at_level(logging.DEBUG): await connection.hello() @@ -181,7 +181,7 @@ async def test_credentials_are_not_logged( def test_auth_message_raises_configuration_error(message, fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt3(address, fake_socket(address), - PoolConfig.max_connection_lifetime) + AsyncPoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError, match="User switching is not supported"): getattr(connection, message)() @@ -196,7 +196,7 @@ def test_auth_message_raises_configuration_error(message, fake_socket): async def test_re_auth_noop(auth, fake_socket, mocker): address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt3(address, fake_socket(address), - PoolConfig.max_connection_lifetime, auth=auth) + AsyncPoolConfig.max_connection_lifetime, auth=auth) logon_spy = mocker.spy(connection, "logon") logoff_spy = mocker.spy(connection, "logoff") res = connection.re_auth(auth, None) @@ -221,7 +221,7 @@ async def test_re_auth_noop(auth, fake_socket, mocker): async def test_re_auth(auth1, auth2, fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt3(address, fake_socket(address), - PoolConfig.max_connection_lifetime, auth=auth1) + AsyncPoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, match="User switching is not supported"): connection.re_auth(auth2, None) @@ -245,7 +245,7 @@ def test_does_not_support_notification_filters(fake_socket, method, address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt3.UNPACKER_CLS) connection = AsyncBolt3(address, socket, - PoolConfig.max_connection_lifetime) + AsyncPoolConfig.max_connection_lifetime) method = getattr(connection, method) with pytest.raises(ConfigurationError, match="Notification filtering"): method(*args, **kwargs) @@ -267,7 +267,7 @@ async def test_hello_does_not_support_notification_filters( address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt3.UNPACKER_CLS) connection = AsyncBolt3( - address, socket, PoolConfig.max_connection_lifetime, + address, socket, AsyncPoolConfig.max_connection_lifetime, **kwargs ) with pytest.raises(ConfigurationError, match="Notification filtering"): diff --git a/tests/unit/async_/io/test_class_bolt4x0.py b/tests/unit/async_/io/test_class_bolt4x0.py index 6a0ae824f..80df60ff3 100644 --- a/tests/unit/async_/io/test_class_bolt4x0.py +++ b/tests/unit/async_/io/test_class_bolt4x0.py @@ -21,8 +21,8 @@ import neo4j from neo4j._api import TelemetryAPI +from neo4j._async.config import AsyncPoolConfig from neo4j._async.io._bolt4 import AsyncBolt4x0 -from neo4j._conf import PoolConfig from neo4j._meta import USER_AGENT from neo4j.exceptions import ConfigurationError @@ -63,7 +63,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): async def test_db_extra_in_begin(fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) - connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x0(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.begin(db="something") await connection.send_all() tag, fields = await socket.pop_message() @@ -76,7 +76,7 @@ async def test_db_extra_in_begin(fake_socket): async def test_db_extra_in_run(fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) - connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x0(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.run("", {}, db="something") await connection.send_all() tag, fields = await socket.pop_message() @@ -91,7 +91,7 @@ async def test_db_extra_in_run(fake_socket): async def test_n_extra_in_discard(fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) - connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x0(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.discard(n=666) await connection.send_all() tag, fields = await socket.pop_message() @@ -111,7 +111,7 @@ async def test_n_extra_in_discard(fake_socket): async def test_qid_extra_in_discard(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) - connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x0(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.discard(qid=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -131,7 +131,7 @@ async def test_qid_extra_in_discard(fake_socket, test_input, expected): async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) - connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x0(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -151,7 +151,7 @@ async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): async def test_n_extra_in_pull(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) - connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x0(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.pull(n=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -171,7 +171,7 @@ async def test_n_extra_in_pull(fake_socket, test_input, expected): async def test_qid_extra_in_pull(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) - connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x0(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.pull(qid=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -184,7 +184,7 @@ async def test_qid_extra_in_pull(fake_socket, test_input, expected): async def test_n_and_qid_extras_in_pull(fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) - connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x0(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) await connection.send_all() tag, fields = await socket.pop_message() @@ -203,7 +203,7 @@ async def test_telemetry_message( address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) connection = AsyncBolt4x0( - address, socket, PoolConfig.max_connection_lifetime, + address, socket, AsyncPoolConfig.max_connection_lifetime, telemetry_disabled=driver_disabled ) if serv_enabled: @@ -230,7 +230,7 @@ async def test_hint_recv_timeout_seconds_gets_ignored( "hints": {"connection.recv_timeout_seconds": recv_timeout}, }) connection = AsyncBolt4x0( - address, sockets.client, PoolConfig.max_connection_lifetime + address, sockets.client, AsyncPoolConfig.max_connection_lifetime ) await connection.hello() sockets.client.settimeout.assert_not_called() @@ -258,7 +258,7 @@ async def test_credentials_are_not_logged( sockets.client.settimeout = mocker.Mock() await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) connection = AsyncBolt4x0( - address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + address, sockets.client, AsyncPoolConfig.max_connection_lifetime, auth=auth ) with caplog.at_level(logging.DEBUG): await connection.hello() @@ -276,7 +276,7 @@ async def test_credentials_are_not_logged( def test_auth_message_raises_configuration_error(message, fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt4x0(address, fake_socket(address), - PoolConfig.max_connection_lifetime) + AsyncPoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError, match="User switching is not supported"): getattr(connection, message)() @@ -291,7 +291,7 @@ def test_auth_message_raises_configuration_error(message, fake_socket): async def test_re_auth_noop(auth, fake_socket, mocker): address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt4x0(address, fake_socket(address), - PoolConfig.max_connection_lifetime, auth=auth) + AsyncPoolConfig.max_connection_lifetime, auth=auth) logon_spy = mocker.spy(connection, "logon") logoff_spy = mocker.spy(connection, "logoff") res = connection.re_auth(auth, None) @@ -316,7 +316,7 @@ async def test_re_auth_noop(auth, fake_socket, mocker): async def test_re_auth(auth1, auth2, fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt4x0(address, fake_socket(address), - PoolConfig.max_connection_lifetime, auth=auth1) + AsyncPoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, match="User switching is not supported"): connection.re_auth(auth2, None) @@ -340,7 +340,7 @@ def test_does_not_support_notification_filters(fake_socket, method, address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) connection = AsyncBolt4x0(address, socket, - PoolConfig.max_connection_lifetime) + AsyncPoolConfig.max_connection_lifetime) method = getattr(connection, method) with pytest.raises(ConfigurationError, match="Notification filtering"): method(*args, **kwargs) @@ -362,7 +362,7 @@ async def test_hello_does_not_support_notification_filters( address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) connection = AsyncBolt4x0( - address, socket, PoolConfig.max_connection_lifetime, + address, socket, AsyncPoolConfig.max_connection_lifetime, **kwargs ) with pytest.raises(ConfigurationError, match="Notification filtering"): diff --git a/tests/unit/async_/io/test_class_bolt4x1.py b/tests/unit/async_/io/test_class_bolt4x1.py index 8b177c9b5..bbe190aab 100644 --- a/tests/unit/async_/io/test_class_bolt4x1.py +++ b/tests/unit/async_/io/test_class_bolt4x1.py @@ -21,8 +21,8 @@ import neo4j from neo4j._api import TelemetryAPI +from neo4j._async.config import AsyncPoolConfig from neo4j._async.io._bolt4 import AsyncBolt4x1 -from neo4j._conf import PoolConfig from neo4j._meta import USER_AGENT from neo4j.exceptions import ConfigurationError @@ -63,7 +63,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): async def test_db_extra_in_begin(fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) - connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x1(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.begin(db="something") await connection.send_all() tag, fields = await socket.pop_message() @@ -76,7 +76,7 @@ async def test_db_extra_in_begin(fake_socket): async def test_db_extra_in_run(fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) - connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x1(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.run("", {}, db="something") await connection.send_all() tag, fields = await socket.pop_message() @@ -91,7 +91,7 @@ async def test_db_extra_in_run(fake_socket): async def test_n_extra_in_discard(fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) - connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x1(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.discard(n=666) await connection.send_all() tag, fields = await socket.pop_message() @@ -111,7 +111,7 @@ async def test_n_extra_in_discard(fake_socket): async def test_qid_extra_in_discard(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) - connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x1(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.discard(qid=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -131,7 +131,7 @@ async def test_qid_extra_in_discard(fake_socket, test_input, expected): async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) - connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x1(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -151,7 +151,7 @@ async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): async def test_n_extra_in_pull(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) - connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x1(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.pull(n=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -171,7 +171,7 @@ async def test_n_extra_in_pull(fake_socket, test_input, expected): async def test_qid_extra_in_pull(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) - connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x1(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.pull(qid=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -184,7 +184,7 @@ async def test_qid_extra_in_pull(fake_socket, test_input, expected): async def test_n_and_qid_extras_in_pull(fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) - connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x1(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) await connection.send_all() tag, fields = await socket.pop_message() @@ -201,7 +201,7 @@ async def test_hello_passes_routing_metadata(fake_socket_pair): unpacker_cls=AsyncBolt4x1.UNPACKER_CLS) await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.1.0"}) connection = AsyncBolt4x1( - address, sockets.client, PoolConfig.max_connection_lifetime, + address, sockets.client, AsyncPoolConfig.max_connection_lifetime, routing_context={"foo": "bar"} ) await connection.hello() @@ -221,7 +221,7 @@ async def test_telemetry_message( address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) connection = AsyncBolt4x1( - address, socket, PoolConfig.max_connection_lifetime, + address, socket, AsyncPoolConfig.max_connection_lifetime, telemetry_disabled=driver_disabled ) if serv_enabled: @@ -248,7 +248,7 @@ async def test_hint_recv_timeout_seconds_gets_ignored( "hints": {"connection.recv_timeout_seconds": recv_timeout}, }) connection = AsyncBolt4x1(address, sockets.client, - PoolConfig.max_connection_lifetime) + AsyncPoolConfig.max_connection_lifetime) await connection.hello() sockets.client.settimeout.assert_not_called() @@ -275,7 +275,7 @@ async def test_credentials_are_not_logged( sockets.client.settimeout = mocker.Mock() await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) connection = AsyncBolt4x1( - address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + address, sockets.client, AsyncPoolConfig.max_connection_lifetime, auth=auth ) with caplog.at_level(logging.DEBUG): await connection.hello() @@ -293,7 +293,7 @@ async def test_credentials_are_not_logged( def test_auth_message_raises_configuration_error(message, fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt4x1(address, fake_socket(address), - PoolConfig.max_connection_lifetime) + AsyncPoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError, match="User switching is not supported"): getattr(connection, message)() @@ -308,7 +308,7 @@ def test_auth_message_raises_configuration_error(message, fake_socket): async def test_re_auth_noop(auth, fake_socket, mocker): address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt4x1(address, fake_socket(address), - PoolConfig.max_connection_lifetime, auth=auth) + AsyncPoolConfig.max_connection_lifetime, auth=auth) logon_spy = mocker.spy(connection, "logon") logoff_spy = mocker.spy(connection, "logoff") res = connection.re_auth(auth, None) @@ -333,7 +333,7 @@ async def test_re_auth_noop(auth, fake_socket, mocker): async def test_re_auth(auth1, auth2, fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt4x1(address, fake_socket(address), - PoolConfig.max_connection_lifetime, auth=auth1) + AsyncPoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, match="User switching is not supported"): connection.re_auth(auth2, None) @@ -357,7 +357,7 @@ def test_does_not_support_notification_filters(fake_socket, method, address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) connection = AsyncBolt4x1(address, socket, - PoolConfig.max_connection_lifetime) + AsyncPoolConfig.max_connection_lifetime) method = getattr(connection, method) with pytest.raises(ConfigurationError, match="Notification filtering"): method(*args, **kwargs) @@ -379,7 +379,7 @@ async def test_hello_does_not_support_notification_filters( address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) connection = AsyncBolt4x1( - address, socket, PoolConfig.max_connection_lifetime, + address, socket, AsyncPoolConfig.max_connection_lifetime, **kwargs ) with pytest.raises(ConfigurationError, match="Notification filtering"): diff --git a/tests/unit/async_/io/test_class_bolt4x2.py b/tests/unit/async_/io/test_class_bolt4x2.py index 5fa82cd18..3bcab1bd8 100644 --- a/tests/unit/async_/io/test_class_bolt4x2.py +++ b/tests/unit/async_/io/test_class_bolt4x2.py @@ -21,8 +21,8 @@ import neo4j from neo4j._api import TelemetryAPI +from neo4j._async.config import AsyncPoolConfig from neo4j._async.io._bolt4 import AsyncBolt4x2 -from neo4j._conf import PoolConfig from neo4j._meta import USER_AGENT from neo4j.exceptions import ConfigurationError @@ -63,7 +63,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): async def test_db_extra_in_begin(fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) - connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x2(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.begin(db="something") await connection.send_all() tag, fields = await socket.pop_message() @@ -76,7 +76,7 @@ async def test_db_extra_in_begin(fake_socket): async def test_db_extra_in_run(fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) - connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x2(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.run("", {}, db="something") await connection.send_all() tag, fields = await socket.pop_message() @@ -91,7 +91,7 @@ async def test_db_extra_in_run(fake_socket): async def test_n_extra_in_discard(fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) - connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x2(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.discard(n=666) await connection.send_all() tag, fields = await socket.pop_message() @@ -111,7 +111,7 @@ async def test_n_extra_in_discard(fake_socket): async def test_qid_extra_in_discard(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) - connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x2(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.discard(qid=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -131,7 +131,7 @@ async def test_qid_extra_in_discard(fake_socket, test_input, expected): async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) - connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x2(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -151,7 +151,7 @@ async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): async def test_n_extra_in_pull(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) - connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x2(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.pull(n=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -171,7 +171,7 @@ async def test_n_extra_in_pull(fake_socket, test_input, expected): async def test_qid_extra_in_pull(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) - connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x2(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.pull(qid=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -184,7 +184,7 @@ async def test_qid_extra_in_pull(fake_socket, test_input, expected): async def test_n_and_qid_extras_in_pull(fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) - connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x2(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) await connection.send_all() tag, fields = await socket.pop_message() @@ -201,7 +201,7 @@ async def test_hello_passes_routing_metadata(fake_socket_pair): unpacker_cls=AsyncBolt4x2.UNPACKER_CLS) await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.2.0"}) connection = AsyncBolt4x2( - address, sockets.client, PoolConfig.max_connection_lifetime, + address, sockets.client, AsyncPoolConfig.max_connection_lifetime, routing_context={"foo": "bar"} ) await connection.hello() @@ -221,7 +221,7 @@ async def test_telemetry_message( address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) connection = AsyncBolt4x2( - address, socket, PoolConfig.max_connection_lifetime, + address, socket, AsyncPoolConfig.max_connection_lifetime, telemetry_disabled=driver_disabled ) if serv_enabled: @@ -248,7 +248,7 @@ async def test_hint_recv_timeout_seconds_gets_ignored( "hints": {"connection.recv_timeout_seconds": recv_timeout}, }) connection = AsyncBolt4x2( - address, sockets.client, PoolConfig.max_connection_lifetime + address, sockets.client, AsyncPoolConfig.max_connection_lifetime ) await connection.hello() sockets.client.settimeout.assert_not_called() @@ -276,7 +276,7 @@ async def test_credentials_are_not_logged( sockets.client.settimeout = mocker.Mock() await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) connection = AsyncBolt4x2( - address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + address, sockets.client, AsyncPoolConfig.max_connection_lifetime, auth=auth ) with caplog.at_level(logging.DEBUG): await connection.hello() @@ -294,7 +294,7 @@ async def test_credentials_are_not_logged( def test_auth_message_raises_configuration_error(message, fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt4x2(address, fake_socket(address), - PoolConfig.max_connection_lifetime) + AsyncPoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError, match="User switching is not supported"): getattr(connection, message)() @@ -309,7 +309,7 @@ def test_auth_message_raises_configuration_error(message, fake_socket): async def test_re_auth_noop(auth, fake_socket, mocker): address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt4x2(address, fake_socket(address), - PoolConfig.max_connection_lifetime, auth=auth) + AsyncPoolConfig.max_connection_lifetime, auth=auth) logon_spy = mocker.spy(connection, "logon") logoff_spy = mocker.spy(connection, "logoff") res = connection.re_auth(auth, None) @@ -334,7 +334,7 @@ async def test_re_auth_noop(auth, fake_socket, mocker): async def test_re_auth(auth1, auth2, fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt4x2(address, fake_socket(address), - PoolConfig.max_connection_lifetime, auth=auth1) + AsyncPoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, match="User switching is not supported"): connection.re_auth(auth2, None) @@ -358,7 +358,7 @@ def test_does_not_support_notification_filters(fake_socket, method, address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) connection = AsyncBolt4x2(address, socket, - PoolConfig.max_connection_lifetime) + AsyncPoolConfig.max_connection_lifetime) method = getattr(connection, method) with pytest.raises(ConfigurationError, match="Notification filtering"): method(*args, **kwargs) @@ -380,7 +380,7 @@ async def test_hello_does_not_support_notification_filters( address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) connection = AsyncBolt4x2( - address, socket, PoolConfig.max_connection_lifetime, + address, socket, AsyncPoolConfig.max_connection_lifetime, **kwargs ) with pytest.raises(ConfigurationError, match="Notification filtering"): diff --git a/tests/unit/async_/io/test_class_bolt4x3.py b/tests/unit/async_/io/test_class_bolt4x3.py index 3e570c489..17fead964 100644 --- a/tests/unit/async_/io/test_class_bolt4x3.py +++ b/tests/unit/async_/io/test_class_bolt4x3.py @@ -21,8 +21,8 @@ import neo4j from neo4j._api import TelemetryAPI +from neo4j._async.config import AsyncPoolConfig from neo4j._async.io._bolt4 import AsyncBolt4x3 -from neo4j._conf import PoolConfig from neo4j._meta import USER_AGENT from neo4j.exceptions import ConfigurationError @@ -63,7 +63,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): async def test_db_extra_in_begin(fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) - connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x3(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.begin(db="something") await connection.send_all() tag, fields = await socket.pop_message() @@ -76,7 +76,7 @@ async def test_db_extra_in_begin(fake_socket): async def test_db_extra_in_run(fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) - connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x3(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.run("", {}, db="something") await connection.send_all() tag, fields = await socket.pop_message() @@ -91,7 +91,7 @@ async def test_db_extra_in_run(fake_socket): async def test_n_extra_in_discard(fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) - connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x3(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.discard(n=666) await connection.send_all() tag, fields = await socket.pop_message() @@ -111,7 +111,7 @@ async def test_n_extra_in_discard(fake_socket): async def test_qid_extra_in_discard(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) - connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x3(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.discard(qid=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -131,7 +131,7 @@ async def test_qid_extra_in_discard(fake_socket, test_input, expected): async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) - connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x3(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -151,7 +151,7 @@ async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): async def test_n_extra_in_pull(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) - connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x3(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.pull(n=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -171,7 +171,7 @@ async def test_n_extra_in_pull(fake_socket, test_input, expected): async def test_qid_extra_in_pull(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) - connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x3(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.pull(qid=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -184,7 +184,7 @@ async def test_qid_extra_in_pull(fake_socket, test_input, expected): async def test_n_and_qid_extras_in_pull(fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) - connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x3(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) await connection.send_all() tag, fields = await socket.pop_message() @@ -201,7 +201,7 @@ async def test_hello_passes_routing_metadata(fake_socket_pair): unpacker_cls=AsyncBolt4x3.UNPACKER_CLS) await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.0"}) connection = AsyncBolt4x3( - address, sockets.client, PoolConfig.max_connection_lifetime, + address, sockets.client, AsyncPoolConfig.max_connection_lifetime, routing_context={"foo": "bar"} ) await connection.hello() @@ -221,7 +221,7 @@ async def test_telemetry_message( address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) connection = AsyncBolt4x3( - address, socket, PoolConfig.max_connection_lifetime, + address, socket, AsyncPoolConfig.max_connection_lifetime, telemetry_disabled=driver_disabled ) if serv_enabled: @@ -258,7 +258,7 @@ async def test_hint_recv_timeout_seconds( b"\x70", {"server": "Neo4j/4.3.0", "hints": hints} ) connection = AsyncBolt4x3( - address, sockets.client, PoolConfig.max_connection_lifetime + address, sockets.client, AsyncPoolConfig.max_connection_lifetime ) with caplog.at_level(logging.INFO): await connection.hello() @@ -302,7 +302,7 @@ async def test_credentials_are_not_logged( sockets.client.settimeout = mocker.Mock() await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) connection = AsyncBolt4x3( - address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + address, sockets.client, AsyncPoolConfig.max_connection_lifetime, auth=auth ) with caplog.at_level(logging.DEBUG): await connection.hello() @@ -321,7 +321,7 @@ async def test_credentials_are_not_logged( def test_auth_message_raises_configuration_error(message, fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt4x3(address, fake_socket(address), - PoolConfig.max_connection_lifetime) + AsyncPoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError, match="User switching is not supported"): getattr(connection, message)() @@ -336,7 +336,7 @@ def test_auth_message_raises_configuration_error(message, fake_socket): async def test_re_auth_noop(auth, fake_socket, mocker): address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt4x3(address, fake_socket(address), - PoolConfig.max_connection_lifetime, auth=auth) + AsyncPoolConfig.max_connection_lifetime, auth=auth) logon_spy = mocker.spy(connection, "logon") logoff_spy = mocker.spy(connection, "logoff") res = connection.re_auth(auth, None) @@ -361,7 +361,7 @@ async def test_re_auth_noop(auth, fake_socket, mocker): async def test_re_auth(auth1, auth2, fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt4x3(address, fake_socket(address), - PoolConfig.max_connection_lifetime, auth=auth1) + AsyncPoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, match="User switching is not supported"): connection.re_auth(auth2, None) @@ -385,7 +385,7 @@ def test_does_not_support_notification_filters(fake_socket, method, address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) connection = AsyncBolt4x3(address, socket, - PoolConfig.max_connection_lifetime) + AsyncPoolConfig.max_connection_lifetime) method = getattr(connection, method) with pytest.raises(ConfigurationError, match="Notification filtering"): method(*args, **kwargs) @@ -407,7 +407,7 @@ async def test_hello_does_not_support_notification_filters( address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) connection = AsyncBolt4x3( - address, socket, PoolConfig.max_connection_lifetime, + address, socket, AsyncPoolConfig.max_connection_lifetime, **kwargs ) with pytest.raises(ConfigurationError, match="Notification filtering"): diff --git a/tests/unit/async_/io/test_class_bolt4x4.py b/tests/unit/async_/io/test_class_bolt4x4.py index 230028ccf..2b615b894 100644 --- a/tests/unit/async_/io/test_class_bolt4x4.py +++ b/tests/unit/async_/io/test_class_bolt4x4.py @@ -21,8 +21,8 @@ import neo4j from neo4j._api import TelemetryAPI +from neo4j._async.config import AsyncPoolConfig from neo4j._async.io._bolt4 import AsyncBolt4x4 -from neo4j._conf import PoolConfig from neo4j._meta import USER_AGENT from neo4j.exceptions import ConfigurationError @@ -72,7 +72,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) - connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x4(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.begin(*args, **kwargs) await connection.send_all() tag, is_fields = await socket.pop_message() @@ -93,7 +93,7 @@ async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) - connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x4(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.run(*args, **kwargs) await connection.send_all() tag, is_fields = await socket.pop_message() @@ -105,7 +105,7 @@ async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): async def test_n_extra_in_discard(fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) - connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x4(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.discard(n=666) await connection.send_all() tag, fields = await socket.pop_message() @@ -125,7 +125,7 @@ async def test_n_extra_in_discard(fake_socket): async def test_qid_extra_in_discard(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) - connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x4(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.discard(qid=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -145,7 +145,7 @@ async def test_qid_extra_in_discard(fake_socket, test_input, expected): async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) - connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x4(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -165,7 +165,7 @@ async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): async def test_n_extra_in_pull(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) - connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x4(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.pull(n=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -185,7 +185,7 @@ async def test_n_extra_in_pull(fake_socket, test_input, expected): async def test_qid_extra_in_pull(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) - connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x4(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.pull(qid=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -198,7 +198,7 @@ async def test_qid_extra_in_pull(fake_socket, test_input, expected): async def test_n_and_qid_extras_in_pull(fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) - connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x4(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) await connection.send_all() tag, fields = await socket.pop_message() @@ -215,7 +215,7 @@ async def test_hello_passes_routing_metadata(fake_socket_pair): unpacker_cls=AsyncBolt4x4.UNPACKER_CLS) await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.4.0"}) connection = AsyncBolt4x4( - address, sockets.client, PoolConfig.max_connection_lifetime, + address, sockets.client, AsyncPoolConfig.max_connection_lifetime, routing_context={"foo": "bar"} ) await connection.hello() @@ -235,7 +235,7 @@ async def test_telemetry_message( address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) connection = AsyncBolt4x4( - address, socket, PoolConfig.max_connection_lifetime, + address, socket, AsyncPoolConfig.max_connection_lifetime, telemetry_disabled=driver_disabled ) if serv_enabled: @@ -272,7 +272,7 @@ async def test_hint_recv_timeout_seconds( b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} ) connection = AsyncBolt4x4( - address, sockets.client, PoolConfig.max_connection_lifetime + address, sockets.client, AsyncPoolConfig.max_connection_lifetime ) with caplog.at_level(logging.INFO): await connection.hello() @@ -316,7 +316,7 @@ async def test_credentials_are_not_logged( sockets.client.settimeout = mocker.Mock() await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) connection = AsyncBolt4x4( - address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + address, sockets.client, AsyncPoolConfig.max_connection_lifetime, auth=auth ) with caplog.at_level(logging.DEBUG): await connection.hello() @@ -334,7 +334,7 @@ async def test_credentials_are_not_logged( def test_auth_message_raises_configuration_error(message, fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt4x4(address, fake_socket(address), - PoolConfig.max_connection_lifetime) + AsyncPoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError, match="User switching is not supported"): getattr(connection, message)() @@ -349,7 +349,7 @@ def test_auth_message_raises_configuration_error(message, fake_socket): async def test_re_auth_noop(auth, fake_socket, mocker): address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt4x4(address, fake_socket(address), - PoolConfig.max_connection_lifetime, auth=auth) + AsyncPoolConfig.max_connection_lifetime, auth=auth) logon_spy = mocker.spy(connection, "logon") logoff_spy = mocker.spy(connection, "logoff") res = connection.re_auth(auth, None) @@ -374,7 +374,7 @@ async def test_re_auth_noop(auth, fake_socket, mocker): async def test_re_auth(auth1, auth2, fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt4x4(address, fake_socket(address), - PoolConfig.max_connection_lifetime, auth=auth1) + AsyncPoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, match="User switching is not supported"): connection.re_auth(auth2, None) @@ -398,7 +398,7 @@ def test_does_not_support_notification_filters(fake_socket, method, address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) connection = AsyncBolt4x4(address, socket, - PoolConfig.max_connection_lifetime) + AsyncPoolConfig.max_connection_lifetime) method = getattr(connection, method) with pytest.raises(ConfigurationError, match="Notification filtering"): method(*args, **kwargs) @@ -420,7 +420,7 @@ async def test_hello_does_not_support_notification_filters( address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) connection = AsyncBolt4x4( - address, socket, PoolConfig.max_connection_lifetime, + address, socket, AsyncPoolConfig.max_connection_lifetime, **kwargs ) with pytest.raises(ConfigurationError, match="Notification filtering"): diff --git a/tests/unit/async_/io/test_class_bolt5x0.py b/tests/unit/async_/io/test_class_bolt5x0.py index bce9a7e64..5b01a2760 100644 --- a/tests/unit/async_/io/test_class_bolt5x0.py +++ b/tests/unit/async_/io/test_class_bolt5x0.py @@ -21,8 +21,8 @@ import neo4j from neo4j._api import TelemetryAPI +from neo4j._async.config import AsyncPoolConfig from neo4j._async.io._bolt5 import AsyncBolt5x0 -from neo4j._conf import PoolConfig from neo4j._meta import USER_AGENT from neo4j.exceptions import ConfigurationError @@ -72,7 +72,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x0.UNPACKER_CLS) - connection = AsyncBolt5x0(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt5x0(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.begin(*args, **kwargs) await connection.send_all() tag, is_fields = await socket.pop_message() @@ -93,7 +93,7 @@ async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x0.UNPACKER_CLS) - connection = AsyncBolt5x0(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt5x0(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.run(*args, **kwargs) await connection.send_all() tag, is_fields = await socket.pop_message() @@ -105,7 +105,7 @@ async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): async def test_n_extra_in_discard(fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x0.UNPACKER_CLS) - connection = AsyncBolt5x0(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt5x0(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.discard(n=666) await connection.send_all() tag, fields = await socket.pop_message() @@ -125,7 +125,7 @@ async def test_n_extra_in_discard(fake_socket): async def test_qid_extra_in_discard(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x0.UNPACKER_CLS) - connection = AsyncBolt5x0(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt5x0(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.discard(qid=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -145,7 +145,7 @@ async def test_qid_extra_in_discard(fake_socket, test_input, expected): async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x0.UNPACKER_CLS) - connection = AsyncBolt5x0(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt5x0(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -165,7 +165,7 @@ async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): async def test_n_extra_in_pull(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x0.UNPACKER_CLS) - connection = AsyncBolt5x0(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt5x0(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.pull(n=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -185,7 +185,7 @@ async def test_n_extra_in_pull(fake_socket, test_input, expected): async def test_qid_extra_in_pull(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x0.UNPACKER_CLS) - connection = AsyncBolt5x0(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt5x0(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.pull(qid=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -198,7 +198,7 @@ async def test_qid_extra_in_pull(fake_socket, test_input, expected): async def test_n_and_qid_extras_in_pull(fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x0.UNPACKER_CLS) - connection = AsyncBolt5x0(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt5x0(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) await connection.send_all() tag, fields = await socket.pop_message() @@ -215,7 +215,7 @@ async def test_hello_passes_routing_metadata(fake_socket_pair): unpacker_cls=AsyncBolt5x0.UNPACKER_CLS) await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.4.0"}) connection = AsyncBolt5x0( - address, sockets.client, PoolConfig.max_connection_lifetime, + address, sockets.client, AsyncPoolConfig.max_connection_lifetime, routing_context={"foo": "bar"} ) await connection.hello() @@ -235,7 +235,7 @@ async def test_telemetry_message( address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x0.UNPACKER_CLS) connection = AsyncBolt5x0( - address, socket, PoolConfig.max_connection_lifetime, + address, socket, AsyncPoolConfig.max_connection_lifetime, telemetry_disabled=driver_disabled ) if serv_enabled: @@ -272,7 +272,7 @@ async def test_hint_recv_timeout_seconds( b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} ) connection = AsyncBolt5x0( - address, sockets.client, PoolConfig.max_connection_lifetime + address, sockets.client, AsyncPoolConfig.max_connection_lifetime ) with caplog.at_level(logging.INFO): await connection.hello() @@ -316,7 +316,7 @@ async def test_credentials_are_not_logged( sockets.client.settimeout = mocker.Mock() await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) connection = AsyncBolt5x0( - address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + address, sockets.client, AsyncPoolConfig.max_connection_lifetime, auth=auth ) with caplog.at_level(logging.DEBUG): await connection.hello() @@ -334,7 +334,7 @@ async def test_credentials_are_not_logged( def test_auth_message_raises_configuration_error(message, fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt5x0(address, fake_socket(address), - PoolConfig.max_connection_lifetime) + AsyncPoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError, match="User switching is not supported"): getattr(connection, message)() @@ -349,7 +349,7 @@ def test_auth_message_raises_configuration_error(message, fake_socket): async def test_re_auth_noop(auth, fake_socket, mocker): address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt5x0(address, fake_socket(address), - PoolConfig.max_connection_lifetime, auth=auth) + AsyncPoolConfig.max_connection_lifetime, auth=auth) logon_spy = mocker.spy(connection, "logon") logoff_spy = mocker.spy(connection, "logoff") res = connection.re_auth(auth, None) @@ -374,7 +374,7 @@ async def test_re_auth_noop(auth, fake_socket, mocker): async def test_re_auth(auth1, auth2, fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt5x0(address, fake_socket(address), - PoolConfig.max_connection_lifetime, auth=auth1) + AsyncPoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, match="User switching is not supported"): connection.re_auth(auth2, None) @@ -398,7 +398,7 @@ def test_does_not_support_notification_filters(fake_socket, method, address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x0.UNPACKER_CLS) connection = AsyncBolt5x0(address, socket, - PoolConfig.max_connection_lifetime) + AsyncPoolConfig.max_connection_lifetime) method = getattr(connection, method) with pytest.raises(ConfigurationError, match="Notification filtering"): method(*args, **kwargs) @@ -420,7 +420,7 @@ async def test_hello_does_not_support_notification_filters( address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x0.UNPACKER_CLS) connection = AsyncBolt5x0( - address, socket, PoolConfig.max_connection_lifetime, + address, socket, AsyncPoolConfig.max_connection_lifetime, **kwargs ) with pytest.raises(ConfigurationError, match="Notification filtering"): diff --git a/tests/unit/async_/io/test_class_bolt5x1.py b/tests/unit/async_/io/test_class_bolt5x1.py index 39342dd9d..69f0729cd 100644 --- a/tests/unit/async_/io/test_class_bolt5x1.py +++ b/tests/unit/async_/io/test_class_bolt5x1.py @@ -22,8 +22,8 @@ import neo4j import neo4j.exceptions from neo4j._api import TelemetryAPI +from neo4j._async.config import AsyncPoolConfig from neo4j._async.io._bolt5 import AsyncBolt5x1 -from neo4j._conf import PoolConfig from neo4j._meta import USER_AGENT from neo4j.exceptions import ConfigurationError @@ -77,7 +77,7 @@ async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x1.UNPACKER_CLS) connection = AsyncBolt5x1(address, socket, - PoolConfig.max_connection_lifetime) + AsyncPoolConfig.max_connection_lifetime) connection.begin(*args, **kwargs) await connection.send_all() tag, is_fields = await socket.pop_message() @@ -99,7 +99,7 @@ async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x1.UNPACKER_CLS) connection = AsyncBolt5x1(address, socket, - PoolConfig.max_connection_lifetime) + AsyncPoolConfig.max_connection_lifetime) connection.run(*args, **kwargs) await connection.send_all() tag, is_fields = await socket.pop_message() @@ -112,7 +112,7 @@ async def test_n_extra_in_discard(fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x1.UNPACKER_CLS) connection = AsyncBolt5x1(address, socket, - PoolConfig.max_connection_lifetime) + AsyncPoolConfig.max_connection_lifetime) connection.discard(n=666) await connection.send_all() tag, fields = await socket.pop_message() @@ -133,7 +133,7 @@ async def test_qid_extra_in_discard(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x1.UNPACKER_CLS) connection = AsyncBolt5x1(address, socket, - PoolConfig.max_connection_lifetime) + AsyncPoolConfig.max_connection_lifetime) connection.discard(qid=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -154,7 +154,7 @@ async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x1.UNPACKER_CLS) connection = AsyncBolt5x1(address, socket, - PoolConfig.max_connection_lifetime) + AsyncPoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -175,7 +175,7 @@ async def test_n_extra_in_pull(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x1.UNPACKER_CLS) connection = AsyncBolt5x1(address, socket, - PoolConfig.max_connection_lifetime) + AsyncPoolConfig.max_connection_lifetime) connection.pull(n=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -196,7 +196,7 @@ async def test_qid_extra_in_pull(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x1.UNPACKER_CLS) connection = AsyncBolt5x1(address, socket, - PoolConfig.max_connection_lifetime) + AsyncPoolConfig.max_connection_lifetime) connection.pull(qid=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -210,7 +210,7 @@ async def test_n_and_qid_extras_in_pull(fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x1.UNPACKER_CLS) connection = AsyncBolt5x1(address, socket, - PoolConfig.max_connection_lifetime) + AsyncPoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) await connection.send_all() tag, fields = await socket.pop_message() @@ -228,7 +228,7 @@ async def test_hello_passes_routing_metadata(fake_socket_pair): await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.4.0"}) await sockets.server.send_message(b"\x70", {}) connection = AsyncBolt5x1( - address, sockets.client, PoolConfig.max_connection_lifetime, + address, sockets.client, AsyncPoolConfig.max_connection_lifetime, routing_context={"foo": "bar"} ) await connection.hello() @@ -248,7 +248,7 @@ async def test_telemetry_message( address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x1.UNPACKER_CLS) connection = AsyncBolt5x1( - address, socket, PoolConfig.max_connection_lifetime, + address, socket, AsyncPoolConfig.max_connection_lifetime, telemetry_disabled=driver_disabled ) if serv_enabled: @@ -282,7 +282,7 @@ async def test_hello_pipelines_logon(fake_socket_pair): "message": "kthxbye"} ) connection = AsyncBolt5x1( - address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + address, sockets.client, AsyncPoolConfig.max_connection_lifetime, auth=auth ) with pytest.raises(neo4j.exceptions.Neo4jError): await connection.hello() @@ -302,7 +302,7 @@ async def test_logon(fake_socket_pair): packer_cls=AsyncBolt5x1.PACKER_CLS, unpacker_cls=AsyncBolt5x1.UNPACKER_CLS) connection = AsyncBolt5x1(address, sockets.client, - PoolConfig.max_connection_lifetime, auth=auth) + AsyncPoolConfig.max_connection_lifetime, auth=auth) connection.logon() await connection.send_all() await _assert_logon_message(sockets, auth) @@ -321,7 +321,7 @@ async def test_re_auth(fake_socket_pair, mocker, static_auth): "message": "kthxbye"} ) connection = AsyncBolt5x1(address, sockets.client, - PoolConfig.max_connection_lifetime) + AsyncPoolConfig.max_connection_lifetime) connection.pool = mocker.AsyncMock() connection.re_auth(auth, auth_manager) await connection.send_all() @@ -343,7 +343,7 @@ async def test_logoff(fake_socket_pair): unpacker_cls=AsyncBolt5x1.UNPACKER_CLS) await sockets.server.send_message(b"\x70", {}) connection = AsyncBolt5x1(address, sockets.client, - PoolConfig.max_connection_lifetime) + AsyncPoolConfig.max_connection_lifetime) connection.logoff() assert not sockets.server.recv_buffer # pipelined, so no response yet await connection.send_all() @@ -379,7 +379,7 @@ async def test_hint_recv_timeout_seconds( ) await sockets.server.send_message(b"\x70", {}) connection = AsyncBolt5x1( - address, sockets.client, PoolConfig.max_connection_lifetime + address, sockets.client, AsyncPoolConfig.max_connection_lifetime ) with caplog.at_level(logging.INFO): await connection.hello() @@ -421,7 +421,7 @@ async def test_credentials_are_not_logged(auth, fake_socket_pair, caplog): await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) await sockets.server.send_message(b"\x70", {}) connection = AsyncBolt5x1( - address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + address, sockets.client, AsyncPoolConfig.max_connection_lifetime, auth=auth ) with caplog.at_level(logging.DEBUG): await connection.hello() @@ -453,7 +453,7 @@ def test_does_not_support_notification_filters(fake_socket, method, address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x1.UNPACKER_CLS) connection = AsyncBolt5x1(address, socket, - PoolConfig.max_connection_lifetime) + AsyncPoolConfig.max_connection_lifetime) method = getattr(connection, method) with pytest.raises(ConfigurationError, match="Notification filtering"): method(*args, **kwargs) @@ -475,7 +475,7 @@ async def test_hello_does_not_support_notification_filters( address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x1.UNPACKER_CLS) connection = AsyncBolt5x1( - address, socket, PoolConfig.max_connection_lifetime, + address, socket, AsyncPoolConfig.max_connection_lifetime, **kwargs ) with pytest.raises(ConfigurationError, match="Notification filtering"): diff --git a/tests/unit/async_/io/test_class_bolt5x2.py b/tests/unit/async_/io/test_class_bolt5x2.py index 309f7335e..d6e6c0ea8 100644 --- a/tests/unit/async_/io/test_class_bolt5x2.py +++ b/tests/unit/async_/io/test_class_bolt5x2.py @@ -21,8 +21,8 @@ import neo4j from neo4j._api import TelemetryAPI +from neo4j._async.config import AsyncPoolConfig from neo4j._async.io._bolt5 import AsyncBolt5x2 -from neo4j._conf import PoolConfig from neo4j._meta import USER_AGENT from ...._async_compat import mark_async_test @@ -71,7 +71,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x2.UNPACKER_CLS) - connection = AsyncBolt5x2(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt5x2(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.begin(*args, **kwargs) await connection.send_all() tag, is_fields = await socket.pop_message() @@ -92,7 +92,7 @@ async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x2.UNPACKER_CLS) - connection = AsyncBolt5x2(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt5x2(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.run(*args, **kwargs) await connection.send_all() tag, is_fields = await socket.pop_message() @@ -104,7 +104,7 @@ async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): async def test_n_extra_in_discard(fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x2.UNPACKER_CLS) - connection = AsyncBolt5x2(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt5x2(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.discard(n=666) await connection.send_all() tag, fields = await socket.pop_message() @@ -124,7 +124,7 @@ async def test_n_extra_in_discard(fake_socket): async def test_qid_extra_in_discard(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x2.UNPACKER_CLS) - connection = AsyncBolt5x2(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt5x2(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.discard(qid=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -144,7 +144,7 @@ async def test_qid_extra_in_discard(fake_socket, test_input, expected): async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x2.UNPACKER_CLS) - connection = AsyncBolt5x2(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt5x2(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -164,7 +164,7 @@ async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): async def test_n_extra_in_pull(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x2.UNPACKER_CLS) - connection = AsyncBolt5x2(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt5x2(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.pull(n=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -184,7 +184,7 @@ async def test_n_extra_in_pull(fake_socket, test_input, expected): async def test_qid_extra_in_pull(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x2.UNPACKER_CLS) - connection = AsyncBolt5x2(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt5x2(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.pull(qid=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -197,7 +197,7 @@ async def test_qid_extra_in_pull(fake_socket, test_input, expected): async def test_n_and_qid_extras_in_pull(fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x2.UNPACKER_CLS) - connection = AsyncBolt5x2(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt5x2(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) await connection.send_all() tag, fields = await socket.pop_message() @@ -215,7 +215,7 @@ async def test_hello_passes_routing_metadata(fake_socket_pair): await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.4.0"}) await sockets.server.send_message(b"\x70", {}) connection = AsyncBolt5x2( - address, sockets.client, PoolConfig.max_connection_lifetime, + address, sockets.client, AsyncPoolConfig.max_connection_lifetime, routing_context={"foo": "bar"} ) await connection.hello() @@ -235,7 +235,7 @@ async def test_telemetry_message( address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x2.UNPACKER_CLS) connection = AsyncBolt5x2( - address, socket, PoolConfig.max_connection_lifetime, + address, socket, AsyncPoolConfig.max_connection_lifetime, telemetry_disabled=driver_disabled ) if serv_enabled: @@ -269,7 +269,7 @@ async def test_hello_pipelines_logon(fake_socket_pair): "message": "kthxbye"} ) connection = AsyncBolt5x2( - address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + address, sockets.client, AsyncPoolConfig.max_connection_lifetime, auth=auth ) with pytest.raises(neo4j.exceptions.Neo4jError): await connection.hello() @@ -289,7 +289,7 @@ async def test_logon(fake_socket_pair): packer_cls=AsyncBolt5x2.PACKER_CLS, unpacker_cls=AsyncBolt5x2.UNPACKER_CLS) connection = AsyncBolt5x2(address, sockets.client, - PoolConfig.max_connection_lifetime, auth=auth) + AsyncPoolConfig.max_connection_lifetime, auth=auth) connection.logon() await connection.send_all() await _assert_logon_message(sockets, auth) @@ -308,7 +308,7 @@ async def test_re_auth(fake_socket_pair, mocker, static_auth): "message": "kthxbye"} ) connection = AsyncBolt5x2(address, sockets.client, - PoolConfig.max_connection_lifetime) + AsyncPoolConfig.max_connection_lifetime) connection.pool = mocker.AsyncMock() connection.re_auth(auth, auth_manager) await connection.send_all() @@ -330,7 +330,7 @@ async def test_logoff(fake_socket_pair): unpacker_cls=AsyncBolt5x2.UNPACKER_CLS) await sockets.server.send_message(b"\x70", {}) connection = AsyncBolt5x2(address, sockets.client, - PoolConfig.max_connection_lifetime) + AsyncPoolConfig.max_connection_lifetime) connection.logoff() assert not sockets.server.recv_buffer # pipelined, so no response yet await connection.send_all() @@ -366,7 +366,7 @@ async def test_hint_recv_timeout_seconds( ) await sockets.server.send_message(b"\x70", {}) connection = AsyncBolt5x2( - address, sockets.client, PoolConfig.max_connection_lifetime + address, sockets.client, AsyncPoolConfig.max_connection_lifetime ) with caplog.at_level(logging.INFO): await connection.hello() @@ -408,7 +408,7 @@ async def test_credentials_are_not_logged(auth, fake_socket_pair, caplog): await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) await sockets.server.send_message(b"\x70", {}) connection = AsyncBolt5x2( - address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + address, sockets.client, AsyncPoolConfig.max_connection_lifetime, auth=auth ) with caplog.at_level(logging.DEBUG): await connection.hello() @@ -448,7 +448,7 @@ async def test_supports_notification_filters( address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x2.UNPACKER_CLS) connection = AsyncBolt5x2( - address, socket, PoolConfig.max_connection_lifetime, + address, socket, AsyncPoolConfig.max_connection_lifetime, notifications_min_severity=cls_min_sev, notifications_disabled_categories=cls_dis_cats ) @@ -482,7 +482,7 @@ async def test_hello_supports_notification_filters( await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) await sockets.server.send_message(b"\x70", {}) connection = AsyncBolt5x2( - address, sockets.client, PoolConfig.max_connection_lifetime, + address, sockets.client, AsyncPoolConfig.max_connection_lifetime, notifications_min_severity=min_sev, notifications_disabled_categories=dis_cats ) diff --git a/tests/unit/async_/io/test_class_bolt5x3.py b/tests/unit/async_/io/test_class_bolt5x3.py index bf5e63080..fa02a8542 100644 --- a/tests/unit/async_/io/test_class_bolt5x3.py +++ b/tests/unit/async_/io/test_class_bolt5x3.py @@ -21,8 +21,8 @@ import neo4j from neo4j._api import TelemetryAPI +from neo4j._async.config import AsyncPoolConfig from neo4j._async.io._bolt5 import AsyncBolt5x3 -from neo4j._conf import PoolConfig from neo4j._meta import ( BOLT_AGENT_DICT, USER_AGENT, @@ -74,7 +74,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x3.UNPACKER_CLS) - connection = AsyncBolt5x3(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt5x3(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.begin(*args, **kwargs) await connection.send_all() tag, is_fields = await socket.pop_message() @@ -95,7 +95,7 @@ async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x3.UNPACKER_CLS) - connection = AsyncBolt5x3(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt5x3(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.run(*args, **kwargs) await connection.send_all() tag, is_fields = await socket.pop_message() @@ -107,7 +107,7 @@ async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): async def test_n_extra_in_discard(fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x3.UNPACKER_CLS) - connection = AsyncBolt5x3(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt5x3(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.discard(n=666) await connection.send_all() tag, fields = await socket.pop_message() @@ -127,7 +127,7 @@ async def test_n_extra_in_discard(fake_socket): async def test_qid_extra_in_discard(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x3.UNPACKER_CLS) - connection = AsyncBolt5x3(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt5x3(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.discard(qid=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -147,7 +147,7 @@ async def test_qid_extra_in_discard(fake_socket, test_input, expected): async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x3.UNPACKER_CLS) - connection = AsyncBolt5x3(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt5x3(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -167,7 +167,7 @@ async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): async def test_n_extra_in_pull(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x3.UNPACKER_CLS) - connection = AsyncBolt5x3(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt5x3(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.pull(n=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -187,7 +187,7 @@ async def test_n_extra_in_pull(fake_socket, test_input, expected): async def test_qid_extra_in_pull(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x3.UNPACKER_CLS) - connection = AsyncBolt5x3(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt5x3(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.pull(qid=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -200,7 +200,7 @@ async def test_qid_extra_in_pull(fake_socket, test_input, expected): async def test_n_and_qid_extras_in_pull(fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x3.UNPACKER_CLS) - connection = AsyncBolt5x3(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt5x3(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) await connection.send_all() tag, fields = await socket.pop_message() @@ -218,7 +218,7 @@ async def test_hello_passes_routing_metadata(fake_socket_pair): await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.4.0"}) await sockets.server.send_message(b"\x70", {}) connection = AsyncBolt5x3( - address, sockets.client, PoolConfig.max_connection_lifetime, + address, sockets.client, AsyncPoolConfig.max_connection_lifetime, routing_context={"foo": "bar"} ) await connection.hello() @@ -238,7 +238,7 @@ async def test_telemetry_message( address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x3.UNPACKER_CLS) connection = AsyncBolt5x3( - address, socket, PoolConfig.max_connection_lifetime, + address, socket, AsyncPoolConfig.max_connection_lifetime, telemetry_disabled=driver_disabled ) if serv_enabled: @@ -276,7 +276,7 @@ async def test_hint_recv_timeout_seconds( ) await sockets.server.send_message(b"\x70", {}) connection = AsyncBolt5x3( - address, sockets.client, PoolConfig.max_connection_lifetime + address, sockets.client, AsyncPoolConfig.max_connection_lifetime ) with caplog.at_level(logging.INFO): await connection.hello() @@ -318,7 +318,7 @@ async def test_credentials_are_not_logged(auth, fake_socket_pair, caplog): await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) await sockets.server.send_message(b"\x70", {}) connection = AsyncBolt5x3( - address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + address, sockets.client, AsyncPoolConfig.max_connection_lifetime, auth=auth ) with caplog.at_level(logging.DEBUG): await connection.hello() @@ -359,7 +359,7 @@ async def test_supports_notification_filters( address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x3.UNPACKER_CLS) connection = AsyncBolt5x3( - address, socket, PoolConfig.max_connection_lifetime, + address, socket, AsyncPoolConfig.max_connection_lifetime, notifications_min_severity=cls_min_sev, notifications_disabled_categories=cls_dis_cats ) @@ -393,7 +393,7 @@ async def test_hello_supports_notification_filters( await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) await sockets.server.send_message(b"\x70", {}) connection = AsyncBolt5x3( - address, sockets.client, PoolConfig.max_connection_lifetime, + address, sockets.client, AsyncPoolConfig.max_connection_lifetime, notifications_min_severity=min_sev, notifications_disabled_categories=dis_cats ) diff --git a/tests/unit/async_/io/test_class_bolt5x4.py b/tests/unit/async_/io/test_class_bolt5x4.py index 043e68f47..a71b4efdd 100644 --- a/tests/unit/async_/io/test_class_bolt5x4.py +++ b/tests/unit/async_/io/test_class_bolt5x4.py @@ -21,8 +21,8 @@ import neo4j from neo4j._api import TelemetryAPI +from neo4j._async.config import AsyncPoolConfig from neo4j._async.io._bolt5 import AsyncBolt5x4 -from neo4j._conf import PoolConfig from neo4j._meta import ( BOLT_AGENT_DICT, USER_AGENT, @@ -74,7 +74,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x4.UNPACKER_CLS) - connection = AsyncBolt5x4(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt5x4(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.begin(*args, **kwargs) await connection.send_all() tag, is_fields = await socket.pop_message() @@ -95,7 +95,7 @@ async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x4.UNPACKER_CLS) - connection = AsyncBolt5x4(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt5x4(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.run(*args, **kwargs) await connection.send_all() tag, is_fields = await socket.pop_message() @@ -107,7 +107,7 @@ async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): async def test_n_extra_in_discard(fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x4.UNPACKER_CLS) - connection = AsyncBolt5x4(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt5x4(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.discard(n=666) await connection.send_all() tag, fields = await socket.pop_message() @@ -127,7 +127,7 @@ async def test_n_extra_in_discard(fake_socket): async def test_qid_extra_in_discard(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x4.UNPACKER_CLS) - connection = AsyncBolt5x4(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt5x4(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.discard(qid=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -147,7 +147,7 @@ async def test_qid_extra_in_discard(fake_socket, test_input, expected): async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x4.UNPACKER_CLS) - connection = AsyncBolt5x4(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt5x4(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -167,7 +167,7 @@ async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): async def test_n_extra_in_pull(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x4.UNPACKER_CLS) - connection = AsyncBolt5x4(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt5x4(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.pull(n=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -187,7 +187,7 @@ async def test_n_extra_in_pull(fake_socket, test_input, expected): async def test_qid_extra_in_pull(fake_socket, test_input, expected): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x4.UNPACKER_CLS) - connection = AsyncBolt5x4(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt5x4(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.pull(qid=test_input) await connection.send_all() tag, fields = await socket.pop_message() @@ -200,7 +200,7 @@ async def test_qid_extra_in_pull(fake_socket, test_input, expected): async def test_n_and_qid_extras_in_pull(fake_socket): address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x4.UNPACKER_CLS) - connection = AsyncBolt5x4(address, socket, PoolConfig.max_connection_lifetime) + connection = AsyncBolt5x4(address, socket, AsyncPoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) await connection.send_all() tag, fields = await socket.pop_message() @@ -218,7 +218,7 @@ async def test_hello_passes_routing_metadata(fake_socket_pair): await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.4.0"}) await sockets.server.send_message(b"\x70", {}) connection = AsyncBolt5x4( - address, sockets.client, PoolConfig.max_connection_lifetime, + address, sockets.client, AsyncPoolConfig.max_connection_lifetime, routing_context={"foo": "bar"} ) await connection.hello() @@ -238,7 +238,7 @@ async def test_telemetry_message( address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x4.UNPACKER_CLS) connection = AsyncBolt5x4( - address, socket, PoolConfig.max_connection_lifetime, + address, socket, AsyncPoolConfig.max_connection_lifetime, telemetry_disabled=driver_disabled ) if serv_enabled: @@ -281,7 +281,7 @@ async def test_hint_recv_timeout_seconds( ) await sockets.server.send_message(b"\x70", {}) connection = AsyncBolt5x4( - address, sockets.client, PoolConfig.max_connection_lifetime + address, sockets.client, AsyncPoolConfig.max_connection_lifetime ) with caplog.at_level(logging.INFO): await connection.hello() @@ -323,7 +323,7 @@ async def test_credentials_are_not_logged(auth, fake_socket_pair, caplog): await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) await sockets.server.send_message(b"\x70", {}) connection = AsyncBolt5x4( - address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + address, sockets.client, AsyncPoolConfig.max_connection_lifetime, auth=auth ) with caplog.at_level(logging.DEBUG): await connection.hello() @@ -364,7 +364,7 @@ async def test_supports_notification_filters( address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x4.UNPACKER_CLS) connection = AsyncBolt5x4( - address, socket, PoolConfig.max_connection_lifetime, + address, socket, AsyncPoolConfig.max_connection_lifetime, notifications_min_severity=cls_min_sev, notifications_disabled_categories=cls_dis_cats ) @@ -398,7 +398,7 @@ async def test_hello_supports_notification_filters( await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) await sockets.server.send_message(b"\x70", {}) connection = AsyncBolt5x4( - address, sockets.client, PoolConfig.max_connection_lifetime, + address, sockets.client, AsyncPoolConfig.max_connection_lifetime, notifications_min_severity=min_sev, notifications_disabled_categories=dis_cats ) diff --git a/tests/unit/async_/io/test_direct.py b/tests/unit/async_/io/test_direct.py index 24a3d26ec..e7f77c6ca 100644 --- a/tests/unit/async_/io/test_direct.py +++ b/tests/unit/async_/io/test_direct.py @@ -17,11 +17,11 @@ import pytest import neo4j +from neo4j._async.config import AsyncPoolConfig from neo4j._async.io import AsyncBolt from neo4j._async.io._pool import AsyncIOPool from neo4j._conf import ( Config, - PoolConfig, WorkspaceConfig, ) from neo4j._deadline import Deadline @@ -40,7 +40,9 @@ class AsyncFakeBoltPool(AsyncIOPool): def __init__(self, connection_gen, address, *, auth=None, **config): self.buffered_connection_mocks = [] config["auth"] = static_auth(None) - self.pool_config, self.workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig) + self.pool_config, self.workspace_config = Config.consume_chain( + config, AsyncPoolConfig, WorkspaceConfig + ) if config: raise ValueError("Unexpected config keys: %s" % ", ".join(config.keys())) diff --git a/tests/unit/async_/io/test_neo4j_pool.py b/tests/unit/async_/io/test_neo4j_pool.py index f14b1dff4..8e4152fb0 100644 --- a/tests/unit/async_/io/test_neo4j_pool.py +++ b/tests/unit/async_/io/test_neo4j_pool.py @@ -20,17 +20,16 @@ import pytest from neo4j import ( - PreviewWarning, READ_ACCESS, WRITE_ACCESS, ) +from neo4j._async.config import AsyncPoolConfig from neo4j._async.io import ( AsyncBolt, AsyncNeo4jPool, ) from neo4j._async_compat.util import AsyncUtil from neo4j._conf import ( - PoolConfig, RoutingConfig, WorkspaceConfig, ) @@ -106,7 +105,7 @@ def opener(custom_routing_opener): def _pool_config(): - pool_config = PoolConfig() + pool_config = AsyncPoolConfig() pool_config.auth = _auth_manager(("user", "pass")) return pool_config @@ -501,7 +500,7 @@ async def test__acquire_new_later_without_room(opener): async def test_passes_pool_config_to_connection(mocker): bolt_mock = mocker.patch.object(AsyncBolt, "open", autospec=True) - pool_config = PoolConfig() + pool_config = AsyncPoolConfig() workspace_config = WorkspaceConfig() pool = AsyncNeo4jPool.open( mocker.Mock, pool_config=pool_config, workspace_config=workspace_config diff --git a/tests/unit/async_/test_auth_manager.py b/tests/unit/async_/test_auth_management.py similarity index 71% rename from tests/unit/async_/test_auth_manager.py rename to tests/unit/async_/test_auth_management.py index 13b836171..ac68e58b1 100644 --- a/tests/unit/async_/test_auth_manager.py +++ b/tests/unit/async_/test_auth_management.py @@ -30,6 +30,10 @@ from neo4j.auth_management import ( AsyncAuthManager, AsyncAuthManagers, + AsyncClientCertificateProvider, + AsyncClientCertificateProviders, + AsyncRotatingClientCertificateProvider, + ClientCertificate, ExpiringAuth, ) from neo4j.exceptions import Neo4jError @@ -237,3 +241,88 @@ async def _test_manager( else: assert await manager.get_auth() is auth1 provider.assert_not_called() + + +@pytest.fixture +def client_cert_factory() -> t.Callable[[], ClientCertificate]: + i = 0 + + def factory() -> ClientCertificate: + with pytest.warns(PreviewWarning, match="Mutual TLS"): + return ClientCertificate(f"cert{i}") + + return factory + + +@copy_signature(AsyncClientCertificateProviders.static) +def static_cert_provider(*args, **kwargs): + with pytest.warns(PreviewWarning, match="Mutual TLS"): + return AsyncClientCertificateProviders.static(*args, **kwargs) + + +@copy_signature(AsyncRotatingClientCertificateProvider) +def rotating_cert_provider_direct(*args, **kwargs): + with pytest.warns(PreviewWarning, match="Mutual TLS"): + return AsyncRotatingClientCertificateProvider(*args, **kwargs) + + +@copy_signature(AsyncClientCertificateProviders.rotating) +def rotating_cert_provider(*args, **kwargs): + with pytest.warns(PreviewWarning, match="Mutual TLS"): + return AsyncClientCertificateProviders.rotating(*args, **kwargs) + + +@mark_async_test +async def test_static_client_cert_provider(client_cert_factory) -> None: + cert1: ClientCertificate = client_cert_factory() + provider: AsyncClientCertificateProvider = static_cert_provider(cert1) + + assert await provider.get_certificate() is cert1 + for _ in range(10): + assert await provider.get_certificate() is None + + +if t.TYPE_CHECKING: + # Tests for type checker only. No need to run the test. + + async def test_rotating_client_cert_provider_type_init( + client_cert_factory + ) -> None: + cert1: ClientCertificate = client_cert_factory() + provider: AsyncRotatingClientCertificateProvider = \ + rotating_cert_provider_direct(cert1) + _: AsyncClientCertificateProvider = provider + + + async def test_rotating_client_cert_provider_type_factory( + client_cert_factory + ) -> None: + cert1: ClientCertificate = client_cert_factory() + provider: AsyncRotatingClientCertificateProvider = \ + rotating_cert_provider(cert1) + _: AsyncClientCertificateProvider = provider + + +@pytest.mark.parametrize( + "factory", (rotating_cert_provider, rotating_cert_provider_direct) +) +@mark_async_test +async def test_rotating_client_cert_provider( + factory: t.Callable[[ClientCertificate], + AsyncRotatingClientCertificateProvider], + client_cert_factory +) -> None: + cert1: ClientCertificate = client_cert_factory() + cert2: ClientCertificate = client_cert_factory() + assert cert1 is not cert2 # sanity check + provider: AsyncRotatingClientCertificateProvider = factory(cert1) + + assert await provider.get_certificate() is cert1 + for _ in range(10): + assert await provider.get_certificate() is None + + await provider.update_certificate(cert2) + + assert await provider.get_certificate() is cert2 + for _ in range(10): + assert await provider.get_certificate() is None diff --git a/tests/unit/async_/test_conf.py b/tests/unit/async_/test_conf.py new file mode 100644 index 000000000..e1d4e178c --- /dev/null +++ b/tests/unit/async_/test_conf.py @@ -0,0 +1,442 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import ssl + +import pytest + +from neo4j import ( + PreviewWarning, + TrustAll, + TrustCustomCAs, + TrustSystemCAs, +) +from neo4j._async.config import AsyncPoolConfig +from neo4j._conf import ( + Config, + SessionConfig, +) +from neo4j.api import ( + TRUST_ALL_CERTIFICATES, + TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, +) +from neo4j.auth_management import ( + AsyncClientCertificateProviders, + ClientCertificate, +) +from neo4j.debug import watch +from neo4j.exceptions import ConfigurationError + +from ..._async_compat import mark_async_test +from ..common.test_conf import test_session_config + + +# python -m pytest tests/unit/test_conf.py -s -v + +watch("neo4j") + +test_pool_config = { + "connection_timeout": 30.0, + "keep_alive": True, + "max_connection_lifetime": 3600, + "liveness_check_timeout": None, + "max_connection_pool_size": 100, + "resolver": None, + "encrypted": False, + "user_agent": "test", + "trusted_certificates": TrustSystemCAs(), + "client_certificate": None, + "ssl_context": None, + "auth": None, + "notifications_min_severity": None, + "notifications_disabled_categories": None, + "telemetry_disabled": False, +} + + +def test_pool_config_consume(): + + test_config = dict(test_pool_config) + + consumed_pool_config = AsyncPoolConfig.consume(test_config) + + assert isinstance(consumed_pool_config, AsyncPoolConfig) + + assert len(test_config) == 0 + + for key in test_pool_config.keys(): + assert consumed_pool_config[key] == test_pool_config[key] + + for key in consumed_pool_config.keys(): + assert test_pool_config[key] == consumed_pool_config[key] + + assert len(consumed_pool_config) == len(test_pool_config) + + +def test_pool_config_consume_default_values(): + + test_config = {} + + consumed_pool_config = AsyncPoolConfig.consume(test_config) + + assert isinstance(consumed_pool_config, AsyncPoolConfig) + + assert len(test_config) == 0 + + consumed_pool_config.keep_alive = "changed" + + assert AsyncPoolConfig.keep_alive != consumed_pool_config.keep_alive + + +def test_pool_config_consume_key_not_valid(): + + test_config = dict(test_pool_config) + + test_config["not_valid_key"] = "test" + + with pytest.raises(ConfigurationError) as error: + AsyncPoolConfig.consume(test_config) + + error.match("Unexpected config keys: not_valid_key") + + +def test_pool_config_set_value(): + + test_config = dict(test_pool_config) + + consumed_pool_config = AsyncPoolConfig.consume(test_config) + + assert consumed_pool_config.get("encrypted") is False + assert consumed_pool_config["encrypted"] is False + assert consumed_pool_config.encrypted is False + + consumed_pool_config.encrypted = "test" + + assert consumed_pool_config.get("encrypted") == "test" + assert consumed_pool_config["encrypted"] == "test" + assert consumed_pool_config.encrypted == "test" + + consumed_pool_config.not_valid_key = "test" # Use consume functions + + +def test_pool_config_consume_and_then_consume_again(): + test_config = dict(test_pool_config) + consumed_pool_config = AsyncPoolConfig.consume(test_config) + assert consumed_pool_config.encrypted is False + consumed_pool_config.encrypted = "test" + + with pytest.raises(AttributeError): + consumed_pool_config = AsyncPoolConfig.consume(consumed_pool_config) + + consumed_pool_config = AsyncPoolConfig.consume(dict(consumed_pool_config.items())) + consumed_pool_config = AsyncPoolConfig.consume(dict(consumed_pool_config.items())) + + assert consumed_pool_config.encrypted == "test" + + +@pytest.mark.parametrize( + ("value_trust", "expected_trusted_certificates_cls"), + ( + (TRUST_ALL_CERTIFICATES, TrustAll), + (TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, TrustSystemCAs), + ) +) +def test_pool_config_deprecated_trust_config( + value_trust, expected_trusted_certificates_cls +): + with pytest.warns(DeprecationWarning, match="trust.*trusted_certificates"): + consumed_pool_config = AsyncPoolConfig.consume({"trust": value_trust}) + assert isinstance(consumed_pool_config.trusted_certificates, + expected_trusted_certificates_cls) + assert not hasattr(consumed_pool_config, "trust") + + +@pytest.mark.parametrize("value_trust", ( + TRUST_ALL_CERTIFICATES, TRUST_SYSTEM_CA_SIGNED_CERTIFICATES +)) +@pytest.mark.parametrize("trusted_certificates", ( + TrustSystemCAs(), TrustAll(), TrustCustomCAs("foo"), + TrustCustomCAs("foo", "bar") +)) +def test_pool_config_deprecated_and_new_trust_config(value_trust, + trusted_certificates): + with pytest.raises(ConfigurationError, + match="trusted_certificates.*trust"): + AsyncPoolConfig.consume({ + "trust": value_trust, + "trusted_certificates": trusted_certificates} + ) + + +def test_config_consume_chain(): + test_config = {} + + test_config.update(test_pool_config) + + test_config.update(test_session_config) + + consumed_pool_config, consumed_session_config = Config.consume_chain( + test_config, AsyncPoolConfig, SessionConfig + ) + + assert isinstance(consumed_pool_config, AsyncPoolConfig) + assert isinstance(consumed_session_config, SessionConfig) + + assert len(test_config) == 0 + + for key, val in test_pool_config.items(): + assert consumed_pool_config[key] == val + + for key, val in consumed_pool_config.items(): + assert test_pool_config[key] == val + + assert len(consumed_pool_config) == len(test_pool_config) + + assert len(consumed_session_config) == len(test_session_config) + + +@pytest.mark.parametrize("config", ( + {}, + {"encrypted": False}, + {"trusted_certificates": TrustSystemCAs()}, + {"trusted_certificates": TrustAll()}, + {"trusted_certificates": TrustCustomCAs("foo", "bar")}, +)) +@mark_async_test +async def test_no_ssl_mock(config, mocker): + ssl_context_mock = mocker.patch("ssl.SSLContext", autospec=True) + pool_config = AsyncPoolConfig.consume(config) + assert pool_config.encrypted is False + assert await pool_config.get_ssl_context() is None + ssl_context_mock.assert_not_called() + # test caching + assert await pool_config.get_ssl_context() is None + ssl_context_mock.assert_not_called() + + +@pytest.mark.parametrize("config", ( + {"encrypted": True}, + {"encrypted": True, "trusted_certificates": TrustSystemCAs()}, +)) +@mark_async_test +async def test_trust_system_cas_mock(config, mocker): + ssl_context_mock = mocker.patch("ssl.SSLContext", autospec=True) + pool_config = AsyncPoolConfig.consume(config) + assert pool_config.encrypted is True + ssl_context = await pool_config.get_ssl_context() + _assert_mock_tls_1_2(ssl_context_mock) + assert ssl_context.minimum_version == ssl.TLSVersion.TLSv1_2 + ssl_context_mock.return_value.load_default_certs.assert_called_once_with() + ssl_context_mock.return_value.load_verify_locations.assert_not_called() + assert ssl_context.check_hostname is True + assert ssl_context.verify_mode == ssl.CERT_REQUIRED + # test caching + ssl_context_mock.reset_mock() + assert await pool_config.get_ssl_context() is ssl_context + ssl_context_mock.assert_not_called() + + +@pytest.mark.parametrize("config", ( + {"encrypted": True, "trusted_certificates": TrustCustomCAs("foo", "bar")}, + {"encrypted": True, "trusted_certificates": TrustCustomCAs()}, +)) +@mark_async_test +async def test_trust_custom_cas_mock(config, mocker): + ssl_context_mock = mocker.patch("ssl.SSLContext", autospec=True) + certs = config["trusted_certificates"].certs + pool_config = AsyncPoolConfig.consume(config) + assert pool_config.encrypted is True + ssl_context = await pool_config.get_ssl_context() + _assert_mock_tls_1_2(ssl_context_mock) + assert ssl_context.minimum_version == ssl.TLSVersion.TLSv1_2 + ssl_context_mock.return_value.load_default_certs.assert_not_called() + assert ( + ssl_context_mock.return_value.load_verify_locations.call_args_list + == [((cert,), {}) for cert in certs] + ) + assert ssl_context.check_hostname is True + assert ssl_context.verify_mode == ssl.CERT_REQUIRED + # test caching + assert await pool_config.get_ssl_context() is ssl_context + + +@pytest.mark.parametrize("config", ( + {"encrypted": True, "trusted_certificates": TrustAll()}, +)) +@mark_async_test +async def test_trust_all_mock(config, mocker): + ssl_context_mock = mocker.patch("ssl.SSLContext", autospec=True) + pool_config = AsyncPoolConfig.consume(config) + assert pool_config.encrypted is True + ssl_context = await pool_config.get_ssl_context() + _assert_mock_tls_1_2(ssl_context_mock) + assert ssl_context.minimum_version == ssl.TLSVersion.TLSv1_2 + ssl_context_mock.return_value.load_default_certs.assert_not_called() + ssl_context_mock.return_value.load_verify_locations.assert_not_called() + assert ssl_context.check_hostname is False + assert ssl_context.verify_mode is ssl.CERT_NONE + # test caching + ssl_context_mock.reset_mock() + assert await pool_config.get_ssl_context() is ssl_context + ssl_context_mock.assert_not_called() + + +def _assert_mock_tls_1_2(mock): + mock.assert_called_once_with(ssl.PROTOCOL_TLS_CLIENT) + assert mock.return_value.minimum_version == ssl.TLSVersion.TLSv1_2 + + +@pytest.mark.parametrize("config", ( + {}, + {"encrypted": False}, + {"trusted_certificates": TrustSystemCAs()}, + {"trusted_certificates": TrustAll()}, + {"trusted_certificates": TrustCustomCAs("foo", "bar")}, +)) +@mark_async_test +async def test_no_ssl(config): + pool_config = AsyncPoolConfig.consume(config) + assert pool_config.encrypted is False + assert await pool_config.get_ssl_context() is None + # test caching + assert await pool_config.get_ssl_context() is None + + +@pytest.mark.parametrize("config", ( + {"encrypted": True}, + {"encrypted": True, "trusted_certificates": TrustSystemCAs()}, +)) +@mark_async_test +async def test_trust_system_cas(config): + pool_config = AsyncPoolConfig.consume(config) + assert pool_config.encrypted is True + ssl_context = await pool_config.get_ssl_context() + assert isinstance(ssl_context, ssl.SSLContext) + _assert_context_tls_1_2(ssl_context) + assert ssl_context.check_hostname is True + assert ssl_context.verify_mode == ssl.CERT_REQUIRED + # test caching + assert await pool_config.get_ssl_context() is ssl_context + + +@pytest.mark.parametrize("config", ( + {"encrypted": True, "trusted_certificates": TrustCustomCAs()}, +)) +@mark_async_test +async def test_trust_custom_cas(config): + pool_config = AsyncPoolConfig.consume(config) + assert pool_config.encrypted is True + ssl_context = await pool_config.get_ssl_context() + assert isinstance(ssl_context, ssl.SSLContext) + _assert_context_tls_1_2(ssl_context) + assert ssl_context.check_hostname is True + assert ssl_context.verify_mode == ssl.CERT_REQUIRED + # test caching + assert await pool_config.get_ssl_context() is ssl_context + + +@pytest.mark.parametrize("config", ( + {"encrypted": True, "trusted_certificates": TrustAll()}, +)) +@mark_async_test +async def test_trust_all(config): + pool_config = AsyncPoolConfig.consume(config) + assert pool_config.encrypted is True + ssl_context = await pool_config.get_ssl_context() + assert isinstance(ssl_context, ssl.SSLContext) + _assert_context_tls_1_2(ssl_context) + assert ssl_context.check_hostname is False + assert ssl_context.verify_mode is ssl.CERT_NONE + # test caching + assert await pool_config.get_ssl_context() is ssl_context + + +def _assert_context_tls_1_2(ctx): + assert ctx.protocol == ssl.PROTOCOL_TLS_CLIENT + assert ctx.minimum_version == ssl.TLSVersion.TLSv1_2 + + +@pytest.mark.parametrize("encrypted", (True, False)) +@pytest.mark.parametrize("trusted_certificates", ( + TrustSystemCAs(), TrustAll(), TrustCustomCAs() +)) +@mark_async_test +async def test_custom_ssl_context(encrypted, trusted_certificates): + custom_ssl_context = object() + pool_config = AsyncPoolConfig.consume({ + "encrypted": encrypted, + "trusted_certificates": trusted_certificates, + "ssl_context": custom_ssl_context, + }) + assert pool_config.encrypted is encrypted + assert await pool_config.get_ssl_context() is custom_ssl_context + # test caching + assert await pool_config.get_ssl_context() is custom_ssl_context + + +@pytest.mark.parametrize("trusted_certificates", ( + TrustSystemCAs(), TrustAll(), TrustCustomCAs() +)) +@mark_async_test +async def test_client_certificate(trusted_certificates, mocker) -> None: + ssl_context_mock = mocker.patch("ssl.SSLContext", autospec=True) + + with pytest.warns(PreviewWarning, match="Mutual TLS"): + cert = ClientCertificate("certfile", "keyfile", "password") + with pytest.warns(PreviewWarning, match="Mutual TLS"): + provider = AsyncClientCertificateProviders.rotating(cert) + pool_config = AsyncPoolConfig.consume({ + "client_certificate": provider, + "encrypted": True, + }) + assert pool_config.client_certificate is provider + + ssl_context = await pool_config.get_ssl_context() + + assert ssl_context is ssl_context_mock.return_value + ssl_context_mock.return_value.load_cert_chain.assert_called_with( + cert.certfile, + keyfile=cert.keyfile, + password=cert.password, + ) + + # test caching + ssl_context_mock.return_value.reset_mock() + ssl_context_mock.reset_mock() + assert await pool_config.get_ssl_context() is ssl_context + ssl_context_mock.return_value.load_cert_chain.assert_not_called() + ssl_context_mock.assert_not_called() + + # test cache invalidation + with pytest.warns(PreviewWarning, match="Mutual TLS"): + cert2 = ClientCertificate("certfile2", "keyfile2", "password2") + await provider.update_certificate(cert2) + + ssl_context = await pool_config.get_ssl_context() + + assert ssl_context is ssl_context_mock.return_value + ssl_context_mock.return_value.load_cert_chain.assert_called_with( + cert2.certfile, + keyfile=cert2.keyfile, + password=cert2.password, + ) + + # test caching + ssl_context_mock.return_value.reset_mock() + ssl_context_mock.reset_mock() + assert await pool_config.get_ssl_context() is ssl_context + ssl_context_mock.return_value.load_cert_chain.assert_not_called() + ssl_context_mock.assert_not_called() diff --git a/tests/unit/async_/test_driver.py b/tests/unit/async_/test_driver.py index e0cf76fe6..a2285b2be 100644 --- a/tests/unit/async_/test_driver.py +++ b/tests/unit/async_/test_driver.py @@ -41,21 +41,24 @@ TrustSystemCAs, ) from neo4j._api import TelemetryAPI +from neo4j._async.auth_management import _AsyncStaticClientCertificateProvider +from neo4j._async.config import AsyncPoolConfig from neo4j._async.driver import _work from neo4j._async.io import ( AsyncBoltPool, AsyncNeo4jPool, ) -from neo4j._conf import ( - PoolConfig, - SessionConfig, -) +from neo4j._conf import SessionConfig from neo4j.api import ( AsyncBookmarkManager, BookmarkManager, READ_ACCESS, WRITE_ACCESS, ) +from neo4j.auth_management import ( + AsyncClientCertificateProvider, + ClientCertificate, +) from neo4j.exceptions import ConfigurationError from ..._async_compat import ( @@ -442,6 +445,52 @@ def forget(self, databases: t.Iterable[str]) -> None: assert session_cls_mock.call_args[0][1].bookmark_manager is bmm +@mark_async_test +async def test_with_static_client_certificate() -> None: + with pytest.warns(neo4j.PreviewWarning, match="Mutual TLS"): + cert = ClientCertificate("foo") + with pytest.warns(neo4j.PreviewWarning, match="Mutual TLS"): + async with AsyncGraphDatabase.driver( + "bolt://localhost", client_certificate=cert + ) as driver: + passed_provider = driver._pool.pool_config.client_certificate + assert isinstance(passed_provider, + _AsyncStaticClientCertificateProvider) + assert passed_provider._cert is cert + + +@mark_async_test +async def test_with_custom_inherited_client_certificate_provider( + session_cls_mock +) -> None: + class Provider(AsyncClientCertificateProvider): + async def get_certificate(self) -> t.Optional[ClientCertificate]: + return None + + provider = Provider() + with pytest.warns(neo4j.PreviewWarning, match="Mutual TLS"): + async with AsyncGraphDatabase.driver( + "bolt://localhost", client_certificate=provider + ) as driver: + assert driver._pool.pool_config.client_certificate is provider + + +@mark_async_test +async def test_with_custom_ducktype_client_certificate_provider( + session_cls_mock +) -> None: + class Provider: + async def get_certificate(self) -> t.Optional[ClientCertificate]: + return None + + provider = Provider() + with pytest.warns(neo4j.PreviewWarning, match="Mutual TLS"): + async with AsyncGraphDatabase.driver( + "bolt://localhost", client_certificate=provider + ) as driver: + assert driver._pool.pool_config.client_certificate is provider + + _T_NotificationMinimumSeverity = t.Union[ NotificationMinimumSeverity, te.Literal[ @@ -526,7 +575,7 @@ async def test_driver_factory_with_notification_filters( ) async with driver: - default_conf = PoolConfig() + default_conf = AsyncPoolConfig() if min_sev is None: expected_min_sev = min_sev elif min_sev is not ...: diff --git a/tests/unit/common/test_conf.py b/tests/unit/common/test_conf.py index 4e0ce912c..f618b3544 100644 --- a/tests/unit/common/test_conf.py +++ b/tests/unit/common/test_conf.py @@ -14,25 +14,14 @@ # limitations under the License. -import ssl - import pytest -from neo4j import ( - TrustAll, - TrustCustomCAs, - TrustSystemCAs, -) from neo4j._conf import ( - Config, - PoolConfig, SessionConfig, WorkspaceConfig, ) from neo4j.api import ( READ_ACCESS, - TRUST_ALL_CERTIFICATES, - TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, WRITE_ACCESS, ) from neo4j.debug import watch @@ -43,23 +32,6 @@ watch("neo4j") -test_pool_config = { - "connection_timeout": 30.0, - "keep_alive": True, - "max_connection_lifetime": 3600, - "liveness_check_timeout": None, - "max_connection_pool_size": 100, - "resolver": None, - "encrypted": False, - "user_agent": "test", - "trusted_certificates": TrustSystemCAs(), - "ssl_context": None, - "auth": None, - "notifications_min_severity": None, - "notifications_disabled_categories": None, - "telemetry_disabled": False, -} - test_session_config = { "connection_acquisition_timeout": 60.0, "max_transaction_retry_time": 30.0, @@ -78,146 +50,6 @@ } -def test_pool_config_consume(): - - test_config = dict(test_pool_config) - - consumed_pool_config = PoolConfig.consume(test_config) - - assert isinstance(consumed_pool_config, PoolConfig) - - assert len(test_config) == 0 - - for key in test_pool_config.keys(): - assert consumed_pool_config[key] == test_pool_config[key] - - for key in consumed_pool_config.keys(): - assert test_pool_config[key] == consumed_pool_config[key] - - assert len(consumed_pool_config) == len(test_pool_config) - - -def test_pool_config_consume_default_values(): - - test_config = {} - - consumed_pool_config = PoolConfig.consume(test_config) - - assert isinstance(consumed_pool_config, PoolConfig) - - assert len(test_config) == 0 - - consumed_pool_config.keep_alive = "changed" - - assert PoolConfig.keep_alive != consumed_pool_config.keep_alive - - -def test_pool_config_consume_key_not_valid(): - - test_config = dict(test_pool_config) - - test_config["not_valid_key"] = "test" - - with pytest.raises(ConfigurationError) as error: - PoolConfig.consume(test_config) - - error.match("Unexpected config keys: not_valid_key") - - -def test_pool_config_set_value(): - - test_config = dict(test_pool_config) - - consumed_pool_config = PoolConfig.consume(test_config) - - assert consumed_pool_config.get("encrypted") is False - assert consumed_pool_config["encrypted"] is False - assert consumed_pool_config.encrypted is False - - consumed_pool_config.encrypted = "test" - - assert consumed_pool_config.get("encrypted") == "test" - assert consumed_pool_config["encrypted"] == "test" - assert consumed_pool_config.encrypted == "test" - - consumed_pool_config.not_valid_key = "test" # Use consume functions - - -def test_pool_config_consume_and_then_consume_again(): - test_config = dict(test_pool_config) - consumed_pool_config = PoolConfig.consume(test_config) - assert consumed_pool_config.encrypted is False - consumed_pool_config.encrypted = "test" - - with pytest.raises(AttributeError): - consumed_pool_config = PoolConfig.consume(consumed_pool_config) - - consumed_pool_config = PoolConfig.consume(dict(consumed_pool_config.items())) - consumed_pool_config = PoolConfig.consume(dict(consumed_pool_config.items())) - - assert consumed_pool_config.encrypted == "test" - - -@pytest.mark.parametrize( - ("value_trust", "expected_trusted_certificates_cls"), - ( - (TRUST_ALL_CERTIFICATES, TrustAll), - (TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, TrustSystemCAs), - ) -) -def test_pool_config_deprecated_trust_config( - value_trust, expected_trusted_certificates_cls -): - with pytest.warns(DeprecationWarning, match="trust.*trusted_certificates"): - consumed_pool_config = PoolConfig.consume({"trust": value_trust}) - assert isinstance(consumed_pool_config.trusted_certificates, - expected_trusted_certificates_cls) - assert not hasattr(consumed_pool_config, "trust") - - -@pytest.mark.parametrize("value_trust", ( - TRUST_ALL_CERTIFICATES, TRUST_SYSTEM_CA_SIGNED_CERTIFICATES -)) -@pytest.mark.parametrize("trusted_certificates", ( - TrustSystemCAs(), TrustAll(), TrustCustomCAs("foo"), - TrustCustomCAs("foo", "bar") -)) -def test_pool_config_deprecated_and_new_trust_config(value_trust, - trusted_certificates): - with pytest.raises(ConfigurationError, - match="trusted_certificates.*trust"): - PoolConfig.consume({"trust": value_trust, - "trusted_certificates": trusted_certificates}) - - -def test_config_consume_chain(): - - test_config = {} - - test_config.update(test_pool_config) - - test_config.update(test_session_config) - - consumed_pool_config, consumed_session_config = Config.consume_chain( - test_config, PoolConfig, SessionConfig - ) - - assert isinstance(consumed_pool_config, PoolConfig) - assert isinstance(consumed_session_config, SessionConfig) - - assert len(test_config) == 0 - - for key, val in test_pool_config.items(): - assert consumed_pool_config[key] == val - - for key, val in consumed_pool_config.items(): - assert test_pool_config[key] == val - - assert len(consumed_pool_config) == len(test_pool_config) - - assert len(consumed_session_config) == len(test_session_config) - - def test_init_session_config_merge(): # python -m pytest tests/unit/test_conf.py -s -v -k test_init_session_config @@ -263,150 +95,3 @@ def test_init_session_config_with_not_valid_key(): _ = SessionConfig.consume(test_config_b) assert session_config.connection_acquisition_timeout == 333 - - -@pytest.mark.parametrize("config", ( - {}, - {"encrypted": False}, - {"trusted_certificates": TrustSystemCAs()}, - {"trusted_certificates": TrustAll()}, - {"trusted_certificates": TrustCustomCAs("foo", "bar")}, -)) -def test_no_ssl_mock(config, mocker): - ssl_context_mock = mocker.patch("ssl.SSLContext", autospec=True) - pool_config = PoolConfig.consume(config) - assert pool_config.encrypted is False - assert pool_config.get_ssl_context() is None - ssl_context_mock.assert_not_called() - - -@pytest.mark.parametrize("config", ( - {"encrypted": True}, - {"encrypted": True, "trusted_certificates": TrustSystemCAs()}, -)) -def test_trust_system_cas_mock(config, mocker): - ssl_context_mock = mocker.patch("ssl.SSLContext", autospec=True) - pool_config = PoolConfig.consume(config) - assert pool_config.encrypted is True - ssl_context = pool_config.get_ssl_context() - _assert_mock_tls_1_2(ssl_context_mock) - assert ssl_context.minimum_version == ssl.TLSVersion.TLSv1_2 - ssl_context_mock.return_value.load_default_certs.assert_called_once_with() - ssl_context_mock.return_value.load_verify_locations.assert_not_called() - assert ssl_context.check_hostname is True - assert ssl_context.verify_mode == ssl.CERT_REQUIRED - - -@pytest.mark.parametrize("config", ( - {"encrypted": True, "trusted_certificates": TrustCustomCAs("foo", "bar")}, - {"encrypted": True, "trusted_certificates": TrustCustomCAs()}, -)) -def test_trust_custom_cas_mock(config, mocker): - ssl_context_mock = mocker.patch("ssl.SSLContext", autospec=True) - certs = config["trusted_certificates"].certs - pool_config = PoolConfig.consume(config) - assert pool_config.encrypted is True - ssl_context = pool_config.get_ssl_context() - _assert_mock_tls_1_2(ssl_context_mock) - assert ssl_context.minimum_version == ssl.TLSVersion.TLSv1_2 - ssl_context_mock.return_value.load_default_certs.assert_not_called() - assert ( - ssl_context_mock.return_value.load_verify_locations.call_args_list - == [((cert,), {}) for cert in certs] - ) - assert ssl_context.check_hostname is True - assert ssl_context.verify_mode == ssl.CERT_REQUIRED - - -@pytest.mark.parametrize("config", ( - {"encrypted": True, "trusted_certificates": TrustAll()}, -)) -def test_trust_all_mock(config, mocker): - ssl_context_mock = mocker.patch("ssl.SSLContext", autospec=True) - pool_config = PoolConfig.consume(config) - assert pool_config.encrypted is True - ssl_context = pool_config.get_ssl_context() - _assert_mock_tls_1_2(ssl_context_mock) - assert ssl_context.minimum_version == ssl.TLSVersion.TLSv1_2 - ssl_context_mock.return_value.load_default_certs.assert_not_called() - ssl_context_mock.return_value.load_verify_locations.assert_not_called() - assert ssl_context.check_hostname is False - assert ssl_context.verify_mode is ssl.CERT_NONE - - -def _assert_mock_tls_1_2(mock): - mock.assert_called_once_with(ssl.PROTOCOL_TLS_CLIENT) - assert mock.return_value.minimum_version == ssl.TLSVersion.TLSv1_2 - - -@pytest.mark.parametrize("config", ( - {}, - {"encrypted": False}, - {"trusted_certificates": TrustSystemCAs()}, - {"trusted_certificates": TrustAll()}, - {"trusted_certificates": TrustCustomCAs("foo", "bar")}, -)) -def test_no_ssl(config): - pool_config = PoolConfig.consume(config) - assert pool_config.encrypted is False - assert pool_config.get_ssl_context() is None - - -@pytest.mark.parametrize("config", ( - {"encrypted": True}, - {"encrypted": True, "trusted_certificates": TrustSystemCAs()}, -)) -def test_trust_system_cas(config): - pool_config = PoolConfig.consume(config) - assert pool_config.encrypted is True - ssl_context = pool_config.get_ssl_context() - assert isinstance(ssl_context, ssl.SSLContext) - _assert_context_tls_1_2(ssl_context) - assert ssl_context.check_hostname is True - assert ssl_context.verify_mode == ssl.CERT_REQUIRED - - -@pytest.mark.parametrize("config", ( - {"encrypted": True, "trusted_certificates": TrustCustomCAs()}, -)) -def test_trust_custom_cas(config): - pool_config = PoolConfig.consume(config) - assert pool_config.encrypted is True - ssl_context = pool_config.get_ssl_context() - assert isinstance(ssl_context, ssl.SSLContext) - _assert_context_tls_1_2(ssl_context) - assert ssl_context.check_hostname is True - assert ssl_context.verify_mode == ssl.CERT_REQUIRED - - -@pytest.mark.parametrize("config", ( - {"encrypted": True, "trusted_certificates": TrustAll()}, -)) -def test_trust_all(config): - pool_config = PoolConfig.consume(config) - assert pool_config.encrypted is True - ssl_context = pool_config.get_ssl_context() - assert isinstance(ssl_context, ssl.SSLContext) - _assert_context_tls_1_2(ssl_context) - assert ssl_context.check_hostname is False - assert ssl_context.verify_mode is ssl.CERT_NONE - - -def _assert_context_tls_1_2(ctx): - assert ctx.protocol == ssl.PROTOCOL_TLS_CLIENT - assert ctx.minimum_version == ssl.TLSVersion.TLSv1_2 - - -@pytest.mark.parametrize("encrypted", (True, False)) -@pytest.mark.parametrize("trusted_certificates", ( - TrustSystemCAs(), TrustAll(), TrustCustomCAs() -)) -def test_custom_ssl_context(encrypted, trusted_certificates): - custom_ssl_context = object() - pool_config = PoolConfig.consume({ - "encrypted": encrypted, - "trusted_certificates": trusted_certificates, - "ssl_context": custom_ssl_context, - }) - assert pool_config.encrypted is encrypted - assert pool_config.get_ssl_context() is custom_ssl_context diff --git a/tests/unit/mixed/io/test_direct.py b/tests/unit/mixed/io/test_direct.py index 6f9fb8f89..63255d254 100644 --- a/tests/unit/mixed/io/test_direct.py +++ b/tests/unit/mixed/io/test_direct.py @@ -32,12 +32,12 @@ from ...async_.conftest import async_fake_connection_generator from ...async_.io.test_direct import AsyncFakeBoltPool -from ...async_.test_auth_manager import ( +from ...async_.test_auth_management import ( static_auth_manager as static_async_auth_manager, ) from ...sync.conftest import fake_connection_generator from ...sync.io.test_direct import FakeBoltPool -from ...sync.test_auth_manager import static_auth_manager +from ...sync.test_auth_management import static_auth_manager from ._common import ( AsyncMultiEvent, MultiEvent, diff --git a/tests/unit/sync/fixtures/fake_pool.py b/tests/unit/sync/fixtures/fake_pool.py index 11a23f667..38d2ac4d1 100644 --- a/tests/unit/sync/fixtures/fake_pool.py +++ b/tests/unit/sync/fixtures/fake_pool.py @@ -16,7 +16,7 @@ import pytest -from neo4j._conf import PoolConfig +from neo4j._sync.config import PoolConfig from neo4j._sync.io._pool import IOPool diff --git a/tests/unit/sync/io/test_class_bolt3.py b/tests/unit/sync/io/test_class_bolt3.py index b522845f5..ead813a38 100644 --- a/tests/unit/sync/io/test_class_bolt3.py +++ b/tests/unit/sync/io/test_class_bolt3.py @@ -22,8 +22,8 @@ import neo4j from neo4j._api import TelemetryAPI -from neo4j._conf import PoolConfig from neo4j._meta import USER_AGENT +from neo4j._sync.config import PoolConfig from neo4j._sync.io._bolt3 import Bolt3 from neo4j.exceptions import ConfigurationError diff --git a/tests/unit/sync/io/test_class_bolt4x0.py b/tests/unit/sync/io/test_class_bolt4x0.py index c30f90246..47e17d550 100644 --- a/tests/unit/sync/io/test_class_bolt4x0.py +++ b/tests/unit/sync/io/test_class_bolt4x0.py @@ -21,8 +21,8 @@ import neo4j from neo4j._api import TelemetryAPI -from neo4j._conf import PoolConfig from neo4j._meta import USER_AGENT +from neo4j._sync.config import PoolConfig from neo4j._sync.io._bolt4 import Bolt4x0 from neo4j.exceptions import ConfigurationError diff --git a/tests/unit/sync/io/test_class_bolt4x1.py b/tests/unit/sync/io/test_class_bolt4x1.py index 38ac7d1a2..0ce6f2d38 100644 --- a/tests/unit/sync/io/test_class_bolt4x1.py +++ b/tests/unit/sync/io/test_class_bolt4x1.py @@ -21,8 +21,8 @@ import neo4j from neo4j._api import TelemetryAPI -from neo4j._conf import PoolConfig from neo4j._meta import USER_AGENT +from neo4j._sync.config import PoolConfig from neo4j._sync.io._bolt4 import Bolt4x1 from neo4j.exceptions import ConfigurationError diff --git a/tests/unit/sync/io/test_class_bolt4x2.py b/tests/unit/sync/io/test_class_bolt4x2.py index 0fb704b6f..33b2e849a 100644 --- a/tests/unit/sync/io/test_class_bolt4x2.py +++ b/tests/unit/sync/io/test_class_bolt4x2.py @@ -21,8 +21,8 @@ import neo4j from neo4j._api import TelemetryAPI -from neo4j._conf import PoolConfig from neo4j._meta import USER_AGENT +from neo4j._sync.config import PoolConfig from neo4j._sync.io._bolt4 import Bolt4x2 from neo4j.exceptions import ConfigurationError diff --git a/tests/unit/sync/io/test_class_bolt4x3.py b/tests/unit/sync/io/test_class_bolt4x3.py index 883e991ca..0e370ac57 100644 --- a/tests/unit/sync/io/test_class_bolt4x3.py +++ b/tests/unit/sync/io/test_class_bolt4x3.py @@ -21,8 +21,8 @@ import neo4j from neo4j._api import TelemetryAPI -from neo4j._conf import PoolConfig from neo4j._meta import USER_AGENT +from neo4j._sync.config import PoolConfig from neo4j._sync.io._bolt4 import Bolt4x3 from neo4j.exceptions import ConfigurationError diff --git a/tests/unit/sync/io/test_class_bolt4x4.py b/tests/unit/sync/io/test_class_bolt4x4.py index e702b413e..384f3067b 100644 --- a/tests/unit/sync/io/test_class_bolt4x4.py +++ b/tests/unit/sync/io/test_class_bolt4x4.py @@ -21,8 +21,8 @@ import neo4j from neo4j._api import TelemetryAPI -from neo4j._conf import PoolConfig from neo4j._meta import USER_AGENT +from neo4j._sync.config import PoolConfig from neo4j._sync.io._bolt4 import Bolt4x4 from neo4j.exceptions import ConfigurationError diff --git a/tests/unit/sync/io/test_class_bolt5x0.py b/tests/unit/sync/io/test_class_bolt5x0.py index 256ee6242..0cec94e8c 100644 --- a/tests/unit/sync/io/test_class_bolt5x0.py +++ b/tests/unit/sync/io/test_class_bolt5x0.py @@ -21,8 +21,8 @@ import neo4j from neo4j._api import TelemetryAPI -from neo4j._conf import PoolConfig from neo4j._meta import USER_AGENT +from neo4j._sync.config import PoolConfig from neo4j._sync.io._bolt5 import Bolt5x0 from neo4j.exceptions import ConfigurationError diff --git a/tests/unit/sync/io/test_class_bolt5x1.py b/tests/unit/sync/io/test_class_bolt5x1.py index 77cf6c7a1..fd5db1c20 100644 --- a/tests/unit/sync/io/test_class_bolt5x1.py +++ b/tests/unit/sync/io/test_class_bolt5x1.py @@ -22,8 +22,8 @@ import neo4j import neo4j.exceptions from neo4j._api import TelemetryAPI -from neo4j._conf import PoolConfig from neo4j._meta import USER_AGENT +from neo4j._sync.config import PoolConfig from neo4j._sync.io._bolt5 import Bolt5x1 from neo4j.exceptions import ConfigurationError diff --git a/tests/unit/sync/io/test_class_bolt5x2.py b/tests/unit/sync/io/test_class_bolt5x2.py index 68da16e84..88becdebe 100644 --- a/tests/unit/sync/io/test_class_bolt5x2.py +++ b/tests/unit/sync/io/test_class_bolt5x2.py @@ -21,8 +21,8 @@ import neo4j from neo4j._api import TelemetryAPI -from neo4j._conf import PoolConfig from neo4j._meta import USER_AGENT +from neo4j._sync.config import PoolConfig from neo4j._sync.io._bolt5 import Bolt5x2 from ...._async_compat import mark_sync_test diff --git a/tests/unit/sync/io/test_class_bolt5x3.py b/tests/unit/sync/io/test_class_bolt5x3.py index 2cb20cc3f..3ed663bcf 100644 --- a/tests/unit/sync/io/test_class_bolt5x3.py +++ b/tests/unit/sync/io/test_class_bolt5x3.py @@ -21,11 +21,11 @@ import neo4j from neo4j._api import TelemetryAPI -from neo4j._conf import PoolConfig from neo4j._meta import ( BOLT_AGENT_DICT, USER_AGENT, ) +from neo4j._sync.config import PoolConfig from neo4j._sync.io._bolt5 import Bolt5x3 from ...._async_compat import mark_sync_test diff --git a/tests/unit/sync/io/test_class_bolt5x4.py b/tests/unit/sync/io/test_class_bolt5x4.py index a002f3f4d..4ed3ad938 100644 --- a/tests/unit/sync/io/test_class_bolt5x4.py +++ b/tests/unit/sync/io/test_class_bolt5x4.py @@ -21,11 +21,11 @@ import neo4j from neo4j._api import TelemetryAPI -from neo4j._conf import PoolConfig from neo4j._meta import ( BOLT_AGENT_DICT, USER_AGENT, ) +from neo4j._sync.config import PoolConfig from neo4j._sync.io._bolt5 import Bolt5x4 from ...._async_compat import mark_sync_test diff --git a/tests/unit/sync/io/test_direct.py b/tests/unit/sync/io/test_direct.py index 9723d7f0c..6a2c620c2 100644 --- a/tests/unit/sync/io/test_direct.py +++ b/tests/unit/sync/io/test_direct.py @@ -19,10 +19,10 @@ import neo4j from neo4j._conf import ( Config, - PoolConfig, WorkspaceConfig, ) from neo4j._deadline import Deadline +from neo4j._sync.config import PoolConfig from neo4j._sync.io import Bolt from neo4j._sync.io._pool import IOPool from neo4j.auth_management import AuthManagers @@ -40,7 +40,9 @@ class FakeBoltPool(IOPool): def __init__(self, connection_gen, address, *, auth=None, **config): self.buffered_connection_mocks = [] config["auth"] = static_auth(None) - self.pool_config, self.workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig) + self.pool_config, self.workspace_config = Config.consume_chain( + config, PoolConfig, WorkspaceConfig + ) if config: raise ValueError("Unexpected config keys: %s" % ", ".join(config.keys())) diff --git a/tests/unit/sync/io/test_neo4j_pool.py b/tests/unit/sync/io/test_neo4j_pool.py index 87c2ce13c..01943c532 100644 --- a/tests/unit/sync/io/test_neo4j_pool.py +++ b/tests/unit/sync/io/test_neo4j_pool.py @@ -20,17 +20,16 @@ import pytest from neo4j import ( - PreviewWarning, READ_ACCESS, WRITE_ACCESS, ) from neo4j._async_compat.util import Util from neo4j._conf import ( - PoolConfig, RoutingConfig, WorkspaceConfig, ) from neo4j._deadline import Deadline +from neo4j._sync.config import PoolConfig from neo4j._sync.io import ( Bolt, Neo4jPool, diff --git a/tests/unit/sync/test_auth_manager.py b/tests/unit/sync/test_auth_management.py similarity index 72% rename from tests/unit/sync/test_auth_manager.py rename to tests/unit/sync/test_auth_management.py index fbe4c8006..598877daf 100644 --- a/tests/unit/sync/test_auth_manager.py +++ b/tests/unit/sync/test_auth_management.py @@ -30,7 +30,11 @@ from neo4j.auth_management import ( AuthManager, AuthManagers, + ClientCertificate, + ClientCertificateProvider, + ClientCertificateProviders, ExpiringAuth, + RotatingClientCertificateProvider, ) from neo4j.exceptions import Neo4jError @@ -237,3 +241,88 @@ def _test_manager( else: assert manager.get_auth() is auth1 provider.assert_not_called() + + +@pytest.fixture +def client_cert_factory() -> t.Callable[[], ClientCertificate]: + i = 0 + + def factory() -> ClientCertificate: + with pytest.warns(PreviewWarning, match="Mutual TLS"): + return ClientCertificate(f"cert{i}") + + return factory + + +@copy_signature(ClientCertificateProviders.static) +def static_cert_provider(*args, **kwargs): + with pytest.warns(PreviewWarning, match="Mutual TLS"): + return ClientCertificateProviders.static(*args, **kwargs) + + +@copy_signature(RotatingClientCertificateProvider) +def rotating_cert_provider_direct(*args, **kwargs): + with pytest.warns(PreviewWarning, match="Mutual TLS"): + return RotatingClientCertificateProvider(*args, **kwargs) + + +@copy_signature(ClientCertificateProviders.rotating) +def rotating_cert_provider(*args, **kwargs): + with pytest.warns(PreviewWarning, match="Mutual TLS"): + return ClientCertificateProviders.rotating(*args, **kwargs) + + +@mark_sync_test +def test_static_client_cert_provider(client_cert_factory) -> None: + cert1: ClientCertificate = client_cert_factory() + provider: ClientCertificateProvider = static_cert_provider(cert1) + + assert provider.get_certificate() is cert1 + for _ in range(10): + assert provider.get_certificate() is None + + +if t.TYPE_CHECKING: + # Tests for type checker only. No need to run the test. + + def test_rotating_client_cert_provider_type_init( + client_cert_factory + ) -> None: + cert1: ClientCertificate = client_cert_factory() + provider: RotatingClientCertificateProvider = \ + rotating_cert_provider_direct(cert1) + _: ClientCertificateProvider = provider + + + def test_rotating_client_cert_provider_type_factory( + client_cert_factory + ) -> None: + cert1: ClientCertificate = client_cert_factory() + provider: RotatingClientCertificateProvider = \ + rotating_cert_provider(cert1) + _: ClientCertificateProvider = provider + + +@pytest.mark.parametrize( + "factory", (rotating_cert_provider, rotating_cert_provider_direct) +) +@mark_sync_test +def test_rotating_client_cert_provider( + factory: t.Callable[[ClientCertificate], + RotatingClientCertificateProvider], + client_cert_factory +) -> None: + cert1: ClientCertificate = client_cert_factory() + cert2: ClientCertificate = client_cert_factory() + assert cert1 is not cert2 # sanity check + provider: RotatingClientCertificateProvider = factory(cert1) + + assert provider.get_certificate() is cert1 + for _ in range(10): + assert provider.get_certificate() is None + + provider.update_certificate(cert2) + + assert provider.get_certificate() is cert2 + for _ in range(10): + assert provider.get_certificate() is None diff --git a/tests/unit/sync/test_conf.py b/tests/unit/sync/test_conf.py new file mode 100644 index 000000000..e78f99328 --- /dev/null +++ b/tests/unit/sync/test_conf.py @@ -0,0 +1,442 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import ssl + +import pytest + +from neo4j import ( + PreviewWarning, + TrustAll, + TrustCustomCAs, + TrustSystemCAs, +) +from neo4j._conf import ( + Config, + SessionConfig, +) +from neo4j._sync.config import PoolConfig +from neo4j.api import ( + TRUST_ALL_CERTIFICATES, + TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, +) +from neo4j.auth_management import ( + ClientCertificate, + ClientCertificateProviders, +) +from neo4j.debug import watch +from neo4j.exceptions import ConfigurationError + +from ..._async_compat import mark_sync_test +from ..common.test_conf import test_session_config + + +# python -m pytest tests/unit/test_conf.py -s -v + +watch("neo4j") + +test_pool_config = { + "connection_timeout": 30.0, + "keep_alive": True, + "max_connection_lifetime": 3600, + "liveness_check_timeout": None, + "max_connection_pool_size": 100, + "resolver": None, + "encrypted": False, + "user_agent": "test", + "trusted_certificates": TrustSystemCAs(), + "client_certificate": None, + "ssl_context": None, + "auth": None, + "notifications_min_severity": None, + "notifications_disabled_categories": None, + "telemetry_disabled": False, +} + + +def test_pool_config_consume(): + + test_config = dict(test_pool_config) + + consumed_pool_config = PoolConfig.consume(test_config) + + assert isinstance(consumed_pool_config, PoolConfig) + + assert len(test_config) == 0 + + for key in test_pool_config.keys(): + assert consumed_pool_config[key] == test_pool_config[key] + + for key in consumed_pool_config.keys(): + assert test_pool_config[key] == consumed_pool_config[key] + + assert len(consumed_pool_config) == len(test_pool_config) + + +def test_pool_config_consume_default_values(): + + test_config = {} + + consumed_pool_config = PoolConfig.consume(test_config) + + assert isinstance(consumed_pool_config, PoolConfig) + + assert len(test_config) == 0 + + consumed_pool_config.keep_alive = "changed" + + assert PoolConfig.keep_alive != consumed_pool_config.keep_alive + + +def test_pool_config_consume_key_not_valid(): + + test_config = dict(test_pool_config) + + test_config["not_valid_key"] = "test" + + with pytest.raises(ConfigurationError) as error: + PoolConfig.consume(test_config) + + error.match("Unexpected config keys: not_valid_key") + + +def test_pool_config_set_value(): + + test_config = dict(test_pool_config) + + consumed_pool_config = PoolConfig.consume(test_config) + + assert consumed_pool_config.get("encrypted") is False + assert consumed_pool_config["encrypted"] is False + assert consumed_pool_config.encrypted is False + + consumed_pool_config.encrypted = "test" + + assert consumed_pool_config.get("encrypted") == "test" + assert consumed_pool_config["encrypted"] == "test" + assert consumed_pool_config.encrypted == "test" + + consumed_pool_config.not_valid_key = "test" # Use consume functions + + +def test_pool_config_consume_and_then_consume_again(): + test_config = dict(test_pool_config) + consumed_pool_config = PoolConfig.consume(test_config) + assert consumed_pool_config.encrypted is False + consumed_pool_config.encrypted = "test" + + with pytest.raises(AttributeError): + consumed_pool_config = PoolConfig.consume(consumed_pool_config) + + consumed_pool_config = PoolConfig.consume(dict(consumed_pool_config.items())) + consumed_pool_config = PoolConfig.consume(dict(consumed_pool_config.items())) + + assert consumed_pool_config.encrypted == "test" + + +@pytest.mark.parametrize( + ("value_trust", "expected_trusted_certificates_cls"), + ( + (TRUST_ALL_CERTIFICATES, TrustAll), + (TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, TrustSystemCAs), + ) +) +def test_pool_config_deprecated_trust_config( + value_trust, expected_trusted_certificates_cls +): + with pytest.warns(DeprecationWarning, match="trust.*trusted_certificates"): + consumed_pool_config = PoolConfig.consume({"trust": value_trust}) + assert isinstance(consumed_pool_config.trusted_certificates, + expected_trusted_certificates_cls) + assert not hasattr(consumed_pool_config, "trust") + + +@pytest.mark.parametrize("value_trust", ( + TRUST_ALL_CERTIFICATES, TRUST_SYSTEM_CA_SIGNED_CERTIFICATES +)) +@pytest.mark.parametrize("trusted_certificates", ( + TrustSystemCAs(), TrustAll(), TrustCustomCAs("foo"), + TrustCustomCAs("foo", "bar") +)) +def test_pool_config_deprecated_and_new_trust_config(value_trust, + trusted_certificates): + with pytest.raises(ConfigurationError, + match="trusted_certificates.*trust"): + PoolConfig.consume({ + "trust": value_trust, + "trusted_certificates": trusted_certificates} + ) + + +def test_config_consume_chain(): + test_config = {} + + test_config.update(test_pool_config) + + test_config.update(test_session_config) + + consumed_pool_config, consumed_session_config = Config.consume_chain( + test_config, PoolConfig, SessionConfig + ) + + assert isinstance(consumed_pool_config, PoolConfig) + assert isinstance(consumed_session_config, SessionConfig) + + assert len(test_config) == 0 + + for key, val in test_pool_config.items(): + assert consumed_pool_config[key] == val + + for key, val in consumed_pool_config.items(): + assert test_pool_config[key] == val + + assert len(consumed_pool_config) == len(test_pool_config) + + assert len(consumed_session_config) == len(test_session_config) + + +@pytest.mark.parametrize("config", ( + {}, + {"encrypted": False}, + {"trusted_certificates": TrustSystemCAs()}, + {"trusted_certificates": TrustAll()}, + {"trusted_certificates": TrustCustomCAs("foo", "bar")}, +)) +@mark_sync_test +def test_no_ssl_mock(config, mocker): + ssl_context_mock = mocker.patch("ssl.SSLContext", autospec=True) + pool_config = PoolConfig.consume(config) + assert pool_config.encrypted is False + assert pool_config.get_ssl_context() is None + ssl_context_mock.assert_not_called() + # test caching + assert pool_config.get_ssl_context() is None + ssl_context_mock.assert_not_called() + + +@pytest.mark.parametrize("config", ( + {"encrypted": True}, + {"encrypted": True, "trusted_certificates": TrustSystemCAs()}, +)) +@mark_sync_test +def test_trust_system_cas_mock(config, mocker): + ssl_context_mock = mocker.patch("ssl.SSLContext", autospec=True) + pool_config = PoolConfig.consume(config) + assert pool_config.encrypted is True + ssl_context = pool_config.get_ssl_context() + _assert_mock_tls_1_2(ssl_context_mock) + assert ssl_context.minimum_version == ssl.TLSVersion.TLSv1_2 + ssl_context_mock.return_value.load_default_certs.assert_called_once_with() + ssl_context_mock.return_value.load_verify_locations.assert_not_called() + assert ssl_context.check_hostname is True + assert ssl_context.verify_mode == ssl.CERT_REQUIRED + # test caching + ssl_context_mock.reset_mock() + assert pool_config.get_ssl_context() is ssl_context + ssl_context_mock.assert_not_called() + + +@pytest.mark.parametrize("config", ( + {"encrypted": True, "trusted_certificates": TrustCustomCAs("foo", "bar")}, + {"encrypted": True, "trusted_certificates": TrustCustomCAs()}, +)) +@mark_sync_test +def test_trust_custom_cas_mock(config, mocker): + ssl_context_mock = mocker.patch("ssl.SSLContext", autospec=True) + certs = config["trusted_certificates"].certs + pool_config = PoolConfig.consume(config) + assert pool_config.encrypted is True + ssl_context = pool_config.get_ssl_context() + _assert_mock_tls_1_2(ssl_context_mock) + assert ssl_context.minimum_version == ssl.TLSVersion.TLSv1_2 + ssl_context_mock.return_value.load_default_certs.assert_not_called() + assert ( + ssl_context_mock.return_value.load_verify_locations.call_args_list + == [((cert,), {}) for cert in certs] + ) + assert ssl_context.check_hostname is True + assert ssl_context.verify_mode == ssl.CERT_REQUIRED + # test caching + assert pool_config.get_ssl_context() is ssl_context + + +@pytest.mark.parametrize("config", ( + {"encrypted": True, "trusted_certificates": TrustAll()}, +)) +@mark_sync_test +def test_trust_all_mock(config, mocker): + ssl_context_mock = mocker.patch("ssl.SSLContext", autospec=True) + pool_config = PoolConfig.consume(config) + assert pool_config.encrypted is True + ssl_context = pool_config.get_ssl_context() + _assert_mock_tls_1_2(ssl_context_mock) + assert ssl_context.minimum_version == ssl.TLSVersion.TLSv1_2 + ssl_context_mock.return_value.load_default_certs.assert_not_called() + ssl_context_mock.return_value.load_verify_locations.assert_not_called() + assert ssl_context.check_hostname is False + assert ssl_context.verify_mode is ssl.CERT_NONE + # test caching + ssl_context_mock.reset_mock() + assert pool_config.get_ssl_context() is ssl_context + ssl_context_mock.assert_not_called() + + +def _assert_mock_tls_1_2(mock): + mock.assert_called_once_with(ssl.PROTOCOL_TLS_CLIENT) + assert mock.return_value.minimum_version == ssl.TLSVersion.TLSv1_2 + + +@pytest.mark.parametrize("config", ( + {}, + {"encrypted": False}, + {"trusted_certificates": TrustSystemCAs()}, + {"trusted_certificates": TrustAll()}, + {"trusted_certificates": TrustCustomCAs("foo", "bar")}, +)) +@mark_sync_test +def test_no_ssl(config): + pool_config = PoolConfig.consume(config) + assert pool_config.encrypted is False + assert pool_config.get_ssl_context() is None + # test caching + assert pool_config.get_ssl_context() is None + + +@pytest.mark.parametrize("config", ( + {"encrypted": True}, + {"encrypted": True, "trusted_certificates": TrustSystemCAs()}, +)) +@mark_sync_test +def test_trust_system_cas(config): + pool_config = PoolConfig.consume(config) + assert pool_config.encrypted is True + ssl_context = pool_config.get_ssl_context() + assert isinstance(ssl_context, ssl.SSLContext) + _assert_context_tls_1_2(ssl_context) + assert ssl_context.check_hostname is True + assert ssl_context.verify_mode == ssl.CERT_REQUIRED + # test caching + assert pool_config.get_ssl_context() is ssl_context + + +@pytest.mark.parametrize("config", ( + {"encrypted": True, "trusted_certificates": TrustCustomCAs()}, +)) +@mark_sync_test +def test_trust_custom_cas(config): + pool_config = PoolConfig.consume(config) + assert pool_config.encrypted is True + ssl_context = pool_config.get_ssl_context() + assert isinstance(ssl_context, ssl.SSLContext) + _assert_context_tls_1_2(ssl_context) + assert ssl_context.check_hostname is True + assert ssl_context.verify_mode == ssl.CERT_REQUIRED + # test caching + assert pool_config.get_ssl_context() is ssl_context + + +@pytest.mark.parametrize("config", ( + {"encrypted": True, "trusted_certificates": TrustAll()}, +)) +@mark_sync_test +def test_trust_all(config): + pool_config = PoolConfig.consume(config) + assert pool_config.encrypted is True + ssl_context = pool_config.get_ssl_context() + assert isinstance(ssl_context, ssl.SSLContext) + _assert_context_tls_1_2(ssl_context) + assert ssl_context.check_hostname is False + assert ssl_context.verify_mode is ssl.CERT_NONE + # test caching + assert pool_config.get_ssl_context() is ssl_context + + +def _assert_context_tls_1_2(ctx): + assert ctx.protocol == ssl.PROTOCOL_TLS_CLIENT + assert ctx.minimum_version == ssl.TLSVersion.TLSv1_2 + + +@pytest.mark.parametrize("encrypted", (True, False)) +@pytest.mark.parametrize("trusted_certificates", ( + TrustSystemCAs(), TrustAll(), TrustCustomCAs() +)) +@mark_sync_test +def test_custom_ssl_context(encrypted, trusted_certificates): + custom_ssl_context = object() + pool_config = PoolConfig.consume({ + "encrypted": encrypted, + "trusted_certificates": trusted_certificates, + "ssl_context": custom_ssl_context, + }) + assert pool_config.encrypted is encrypted + assert pool_config.get_ssl_context() is custom_ssl_context + # test caching + assert pool_config.get_ssl_context() is custom_ssl_context + + +@pytest.mark.parametrize("trusted_certificates", ( + TrustSystemCAs(), TrustAll(), TrustCustomCAs() +)) +@mark_sync_test +def test_client_certificate(trusted_certificates, mocker) -> None: + ssl_context_mock = mocker.patch("ssl.SSLContext", autospec=True) + + with pytest.warns(PreviewWarning, match="Mutual TLS"): + cert = ClientCertificate("certfile", "keyfile", "password") + with pytest.warns(PreviewWarning, match="Mutual TLS"): + provider = ClientCertificateProviders.rotating(cert) + pool_config = PoolConfig.consume({ + "client_certificate": provider, + "encrypted": True, + }) + assert pool_config.client_certificate is provider + + ssl_context = pool_config.get_ssl_context() + + assert ssl_context is ssl_context_mock.return_value + ssl_context_mock.return_value.load_cert_chain.assert_called_with( + cert.certfile, + keyfile=cert.keyfile, + password=cert.password, + ) + + # test caching + ssl_context_mock.return_value.reset_mock() + ssl_context_mock.reset_mock() + assert pool_config.get_ssl_context() is ssl_context + ssl_context_mock.return_value.load_cert_chain.assert_not_called() + ssl_context_mock.assert_not_called() + + # test cache invalidation + with pytest.warns(PreviewWarning, match="Mutual TLS"): + cert2 = ClientCertificate("certfile2", "keyfile2", "password2") + provider.update_certificate(cert2) + + ssl_context = pool_config.get_ssl_context() + + assert ssl_context is ssl_context_mock.return_value + ssl_context_mock.return_value.load_cert_chain.assert_called_with( + cert2.certfile, + keyfile=cert2.keyfile, + password=cert2.password, + ) + + # test caching + ssl_context_mock.return_value.reset_mock() + ssl_context_mock.reset_mock() + assert pool_config.get_ssl_context() is ssl_context + ssl_context_mock.return_value.load_cert_chain.assert_not_called() + ssl_context_mock.assert_not_called() diff --git a/tests/unit/sync/test_driver.py b/tests/unit/sync/test_driver.py index 8095a8421..0a3c6c332 100644 --- a/tests/unit/sync/test_driver.py +++ b/tests/unit/sync/test_driver.py @@ -41,10 +41,9 @@ TrustSystemCAs, ) from neo4j._api import TelemetryAPI -from neo4j._conf import ( - PoolConfig, - SessionConfig, -) +from neo4j._conf import SessionConfig +from neo4j._sync.auth_management import _StaticClientCertificateProvider +from neo4j._sync.config import PoolConfig from neo4j._sync.driver import _work from neo4j._sync.io import ( BoltPool, @@ -55,6 +54,10 @@ READ_ACCESS, WRITE_ACCESS, ) +from neo4j.auth_management import ( + ClientCertificate, + ClientCertificateProvider, +) from neo4j.exceptions import ConfigurationError from ..._async_compat import ( @@ -441,6 +444,52 @@ def forget(self, databases: t.Iterable[str]) -> None: assert session_cls_mock.call_args[0][1].bookmark_manager is bmm +@mark_sync_test +def test_with_static_client_certificate() -> None: + with pytest.warns(neo4j.PreviewWarning, match="Mutual TLS"): + cert = ClientCertificate("foo") + with pytest.warns(neo4j.PreviewWarning, match="Mutual TLS"): + with GraphDatabase.driver( + "bolt://localhost", client_certificate=cert + ) as driver: + passed_provider = driver._pool.pool_config.client_certificate + assert isinstance(passed_provider, + _StaticClientCertificateProvider) + assert passed_provider._cert is cert + + +@mark_sync_test +def test_with_custom_inherited_client_certificate_provider( + session_cls_mock +) -> None: + class Provider(ClientCertificateProvider): + def get_certificate(self) -> t.Optional[ClientCertificate]: + return None + + provider = Provider() + with pytest.warns(neo4j.PreviewWarning, match="Mutual TLS"): + with GraphDatabase.driver( + "bolt://localhost", client_certificate=provider + ) as driver: + assert driver._pool.pool_config.client_certificate is provider + + +@mark_sync_test +def test_with_custom_ducktype_client_certificate_provider( + session_cls_mock +) -> None: + class Provider: + def get_certificate(self) -> t.Optional[ClientCertificate]: + return None + + provider = Provider() + with pytest.warns(neo4j.PreviewWarning, match="Mutual TLS"): + with GraphDatabase.driver( + "bolt://localhost", client_certificate=provider + ) as driver: + assert driver._pool.pool_config.client_certificate is provider + + _T_NotificationMinimumSeverity = t.Union[ NotificationMinimumSeverity, te.Literal[