diff --git a/CHANGELOG.md b/CHANGELOG.md index 31d8fac7f..b5931b531 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,7 +3,9 @@ See also https://github.com/neo4j/neo4j-python-driver/wiki for a full changelog. ## NEXT RELEASE -- No breaking or major changes. +- `neo4j.auth_management.ExpiringAuth`'s `expires_in` (in preview) was replaced + by `expires_at`, which is a unix timestamp. + You can use `ExpiringAuth(some_auth).expires_in(123)` instead. ## Version 5.8 diff --git a/src/neo4j/_async/auth_management.py b/src/neo4j/_async/auth_management.py index 6fe02d0a1..e423ddac9 100644 --- a/src/neo4j/_async/auth_management.py +++ b/src/neo4j/_async/auth_management.py @@ -28,6 +28,7 @@ from .._async_compat.concurrency import AsyncLock from .._auth_management import ( AsyncAuthManager, + expiring_auth_has_expired, ExpiringAuth, ) from .._meta import preview @@ -54,24 +55,8 @@ async def on_auth_expired(self, auth: _TAuth) -> None: pass -class _ExpiringAuthHolder: - def __init__(self, auth: ExpiringAuth) -> None: - self._auth = auth - self._expiry = None - if auth.expires_in is not None: - self._expiry = time.monotonic() + auth.expires_in - - @property - def auth(self) -> _TAuth: - return self._auth.auth - - def expired(self) -> bool: - if self._expiry is None: - return False - return time.monotonic() > self._expiry - class AsyncExpirationBasedAuthManager(AsyncAuthManager): - _current_auth: t.Optional[_ExpiringAuthHolder] + _current_auth: t.Optional[ExpiringAuth] _provider: t.Callable[[], t.Awaitable[ExpiringAuth]] _lock: AsyncLock @@ -85,17 +70,22 @@ def __init__( self._lock = AsyncLock() async def _refresh_auth(self): - self._current_auth = _ExpiringAuthHolder(await self._provider()) + self._current_auth = await self._provider() + if self._current_auth is None: + raise TypeError( + "Auth provider function passed to expiration_based " + "AuthManager returned None, expected ExpiringAuth" + ) async def get_auth(self) -> _TAuth: async with self._lock: auth = self._current_auth - if auth is not None and not auth.expired(): - return auth.auth - log.debug("[ ] _: refreshing (time out)") - await self._refresh_auth() - assert self._current_auth is not None - return self._current_auth.auth + if auth is None or expiring_auth_has_expired(auth): + log.debug("[ ] _: refreshing (time out)") + await self._refresh_auth() + auth = self._current_auth + assert auth is not None + return auth.auth async def on_auth_expired(self, auth: _TAuth) -> None: async with self._lock: diff --git a/src/neo4j/_auth_management.py b/src/neo4j/_auth_management.py index 14ff89db0..c4f7b259f 100644 --- a/src/neo4j/_auth_management.py +++ b/src/neo4j/_auth_management.py @@ -22,10 +22,15 @@ import abc +import time import typing as t +import warnings from dataclasses import dataclass -from ._meta import preview +from ._meta import ( + preview, + PreviewWarning, +) from .api import _TAuth @@ -38,9 +43,11 @@ class ExpiringAuth: :meth:`.AsyncAuthManagers.expiration_based`. :param auth: The authentication information. - :param expires_in: The number of seconds until the authentication - information expires. If :data:`None`, the authentication information - is considered to not expire until the server explicitly indicates so. + :param expires_at: + Unix timestamp (seconds since 1970-01-01 00:00:00 UTC) + indicating when the authentication information expires. + If :data:`None`, the authentication information is considered to not + expire until the server explicitly indicates so. **This is a preview** (see :ref:`filter-warnings-ref`). It might be changed without following the deprecation policy. @@ -52,8 +59,37 @@ class ExpiringAuth: .. versionadded:: 5.8 """ - auth: _TAuth - expires_in: t.Optional[float] = None + auth: "_TAuth" + expires_at: t.Optional[float] = None + + def expires_in(self, seconds: float) -> "ExpiringAuth": + """Return a copy of this object with a new expiration time. + + This is a convenience method for creating an :class:`.ExpiringAuth` + for a relative expiration time ("expires in" instead of "expires at"). + + >>> import time, freezegun + >>> with freezegun.freeze_time("1970-01-01 00:00:40"): + ... ExpiringAuth(("user", "pass")).expires_in(2) + ExpiringAuth(auth=('user', 'pass'), expires_at=42.0) + >>> with freezegun.freeze_time("1970-01-01 00:00:40"): + ... ExpiringAuth(("user", "pass"), time.time() + 2) + ExpiringAuth(auth=('user', 'pass'), expires_at=42.0) + + :param seconds: + The number of seconds from now until the authentication information + expires. + """ + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", + message=r"^Auth managers\b.*", + category=PreviewWarning) + return ExpiringAuth(self.auth, time.time() + seconds) + + +def expiring_auth_has_expired(auth: ExpiringAuth) -> bool: + expires_at = auth.expires_at + return expires_at is not None and expires_at < time.time() class AuthManager(metaclass=abc.ABCMeta): diff --git a/src/neo4j/_sync/auth_management.py b/src/neo4j/_sync/auth_management.py index 4513860a7..f54c70780 100644 --- a/src/neo4j/_sync/auth_management.py +++ b/src/neo4j/_sync/auth_management.py @@ -28,6 +28,7 @@ from .._async_compat.concurrency import Lock from .._auth_management import ( AuthManager, + expiring_auth_has_expired, ExpiringAuth, ) from .._meta import preview @@ -54,24 +55,8 @@ def on_auth_expired(self, auth: _TAuth) -> None: pass -class _ExpiringAuthHolder: - def __init__(self, auth: ExpiringAuth) -> None: - self._auth = auth - self._expiry = None - if auth.expires_in is not None: - self._expiry = time.monotonic() + auth.expires_in - - @property - def auth(self) -> _TAuth: - return self._auth.auth - - def expired(self) -> bool: - if self._expiry is None: - return False - return time.monotonic() > self._expiry - class ExpirationBasedAuthManager(AuthManager): - _current_auth: t.Optional[_ExpiringAuthHolder] + _current_auth: t.Optional[ExpiringAuth] _provider: t.Callable[[], t.Union[ExpiringAuth]] _lock: Lock @@ -85,17 +70,22 @@ def __init__( self._lock = Lock() def _refresh_auth(self): - self._current_auth = _ExpiringAuthHolder(self._provider()) + self._current_auth = self._provider() + if self._current_auth is None: + raise TypeError( + "Auth provider function passed to expiration_based " + "AuthManager returned None, expected ExpiringAuth" + ) def get_auth(self) -> _TAuth: with self._lock: auth = self._current_auth - if auth is not None and not auth.expired(): - return auth.auth - log.debug("[ ] _: refreshing (time out)") - self._refresh_auth() - assert self._current_auth is not None - return self._current_auth.auth + if auth is None or expiring_auth_has_expired(auth): + log.debug("[ ] _: refreshing (time out)") + self._refresh_auth() + auth = self._current_auth + assert auth is not None + return auth.auth def on_auth_expired(self, auth: _TAuth) -> None: with self._lock: diff --git a/testkitbackend/_async/requests.py b/testkitbackend/_async/requests.py index 6feffca6d..dfe891a1e 100644 --- a/testkitbackend/_async/requests.py +++ b/testkitbackend/_async/requests.py @@ -271,13 +271,12 @@ async def ExpirationBasedAuthTokenProviderCompleted(backend, data): "AuthTokenAndExpiration") temp_auth_data = temp_auth_data["data"] auth_token = fromtestkit.to_auth_token(temp_auth_data, "auth") - if temp_auth_data["expiresInMs"] is not None: - expires_in = temp_auth_data["expiresInMs"] / 1000 - else: - expires_in = None with warning_check(neo4j.PreviewWarning, "Auth managers are a preview feature."): - expiring_auth = ExpiringAuth(auth_token, expires_in) + expiring_auth = ExpiringAuth(auth_token) + if temp_auth_data["expiresInMs"] is not None: + expires_in = temp_auth_data["expiresInMs"] / 1000 + expiring_auth = expiring_auth.expires_in(expires_in) backend.expiring_auth_token_supplies[data["requestId"]] = expiring_auth diff --git a/testkitbackend/_sync/requests.py b/testkitbackend/_sync/requests.py index 3a23e3285..bd96f751d 100644 --- a/testkitbackend/_sync/requests.py +++ b/testkitbackend/_sync/requests.py @@ -271,13 +271,12 @@ def ExpirationBasedAuthTokenProviderCompleted(backend, data): "AuthTokenAndExpiration") temp_auth_data = temp_auth_data["data"] auth_token = fromtestkit.to_auth_token(temp_auth_data, "auth") - if temp_auth_data["expiresInMs"] is not None: - expires_in = temp_auth_data["expiresInMs"] / 1000 - else: - expires_in = None with warning_check(neo4j.PreviewWarning, "Auth managers are a preview feature."): - expiring_auth = ExpiringAuth(auth_token, expires_in) + expiring_auth = ExpiringAuth(auth_token) + if temp_auth_data["expiresInMs"] is not None: + expires_in = temp_auth_data["expiresInMs"] / 1000 + expiring_auth = expiring_auth.expires_in(expires_in) backend.expiring_auth_token_supplies[data["requestId"]] = expiring_auth diff --git a/tests/unit/async_/test_auth_manager.py b/tests/unit/async_/test_auth_manager.py index e29abacf1..65704204f 100644 --- a/tests/unit/async_/test_auth_manager.py +++ b/tests/unit/async_/test_auth_manager.py @@ -82,52 +82,51 @@ async def test_static_manager( @mark_async_test @pytest.mark.parametrize(("auth1", "auth2"), itertools.product(SAMPLE_AUTHS, repeat=2)) -@pytest.mark.parametrize("expires_in", (None, -1, 1., 1, 1000.)) +@pytest.mark.parametrize("expires_at", (None, .001, 1, 1000.)) async def test_expiration_based_manager_manual_expiry( auth1: t.Union[t.Tuple[str, str], Auth, None], auth2: t.Union[t.Tuple[str, str], Auth, None], - expires_in: t.Union[float, int], + expires_at: t.Optional[float], mocker ) -> None: - if expires_in is None or expires_in >= 0: - temporal_auth = expiring_auth(auth1, expires_in) - else: - temporal_auth = expiring_auth(auth1) - provider = mocker.AsyncMock(return_value=temporal_auth) - manager: AsyncAuthManager = expiration_based_auth_manager(provider) + with freeze_time("1970-01-01 00:00:00") as frozen_time: + assert isinstance(frozen_time, FrozenDateTimeFactory) + temporal_auth = expiring_auth(auth1, expires_at) + provider = mocker.AsyncMock(return_value=temporal_auth) + manager: AsyncAuthManager = expiration_based_auth_manager(provider) - provider.assert_not_called() - assert await manager.get_auth() is auth1 - provider.assert_awaited_once() - provider.reset_mock() + provider.assert_not_called() + assert await manager.get_auth() is auth1 + provider.assert_awaited_once() + provider.reset_mock() - provider.return_value = expiring_auth(auth2) + provider.return_value = expiring_auth(auth2) - await manager.on_auth_expired(("something", "else")) - assert await manager.get_auth() is auth1 - provider.assert_not_called() + await manager.on_auth_expired(("something", "else")) + assert await manager.get_auth() is auth1 + provider.assert_not_called() - await manager.on_auth_expired(auth1) - provider.assert_awaited_once() - provider.reset_mock() - assert await manager.get_auth() is auth2 - provider.assert_not_called() + await manager.on_auth_expired(auth1) + provider.assert_awaited_once() + provider.reset_mock() + assert await manager.get_auth() is auth2 + provider.assert_not_called() @mark_async_test @pytest.mark.parametrize(("auth1", "auth2"), itertools.product(SAMPLE_AUTHS, repeat=2)) -@pytest.mark.parametrize("expires_in", (None, -1, 1., 1, 1000.)) +@pytest.mark.parametrize("expires_at", (None, -1, 1., 1, 1000.)) async def test_expiration_based_manager_time_expiry( auth1: t.Union[t.Tuple[str, str], Auth, None], auth2: t.Union[t.Tuple[str, str], Auth, None], - expires_in: t.Union[float, int, None], + expires_at: t.Optional[float], mocker ) -> None: - with freeze_time() as frozen_time: + with freeze_time("1970-01-01 00:00:00") as frozen_time: assert isinstance(frozen_time, FrozenDateTimeFactory) - if expires_in is None or expires_in >= 0: - temporal_auth = expiring_auth(auth1, expires_in) + if expires_at is None or expires_at >= 0: + temporal_auth = expiring_auth(auth1, expires_at) else: temporal_auth = expiring_auth(auth1) provider = mocker.AsyncMock(return_value=temporal_auth) @@ -140,12 +139,12 @@ async def test_expiration_based_manager_time_expiry( provider.return_value = expiring_auth(auth2) - if expires_in is None or expires_in < 0: + if expires_at is None or expires_at < 0: frozen_time.tick(1_000_000) assert await manager.get_auth() is auth1 provider.assert_not_called() else: - frozen_time.tick(expires_in - 0.000001) + frozen_time.tick(expires_at - 0.000001) assert await manager.get_auth() is auth1 provider.assert_not_called() frozen_time.tick(0.000002) diff --git a/tests/unit/sync/test_auth_manager.py b/tests/unit/sync/test_auth_manager.py index 15d8b5982..634fc42a6 100644 --- a/tests/unit/sync/test_auth_manager.py +++ b/tests/unit/sync/test_auth_manager.py @@ -82,52 +82,51 @@ def test_static_manager( @mark_sync_test @pytest.mark.parametrize(("auth1", "auth2"), itertools.product(SAMPLE_AUTHS, repeat=2)) -@pytest.mark.parametrize("expires_in", (None, -1, 1., 1, 1000.)) +@pytest.mark.parametrize("expires_at", (None, .001, 1, 1000.)) def test_expiration_based_manager_manual_expiry( auth1: t.Union[t.Tuple[str, str], Auth, None], auth2: t.Union[t.Tuple[str, str], Auth, None], - expires_in: t.Union[float, int], + expires_at: t.Optional[float], mocker ) -> None: - if expires_in is None or expires_in >= 0: - temporal_auth = expiring_auth(auth1, expires_in) - else: - temporal_auth = expiring_auth(auth1) - provider = mocker.MagicMock(return_value=temporal_auth) - manager: AuthManager = expiration_based_auth_manager(provider) + with freeze_time("1970-01-01 00:00:00") as frozen_time: + assert isinstance(frozen_time, FrozenDateTimeFactory) + temporal_auth = expiring_auth(auth1, expires_at) + provider = mocker.MagicMock(return_value=temporal_auth) + manager: AuthManager = expiration_based_auth_manager(provider) - provider.assert_not_called() - assert manager.get_auth() is auth1 - provider.assert_called_once() - provider.reset_mock() + provider.assert_not_called() + assert manager.get_auth() is auth1 + provider.assert_called_once() + provider.reset_mock() - provider.return_value = expiring_auth(auth2) + provider.return_value = expiring_auth(auth2) - manager.on_auth_expired(("something", "else")) - assert manager.get_auth() is auth1 - provider.assert_not_called() + manager.on_auth_expired(("something", "else")) + assert manager.get_auth() is auth1 + provider.assert_not_called() - manager.on_auth_expired(auth1) - provider.assert_called_once() - provider.reset_mock() - assert manager.get_auth() is auth2 - provider.assert_not_called() + manager.on_auth_expired(auth1) + provider.assert_called_once() + provider.reset_mock() + assert manager.get_auth() is auth2 + provider.assert_not_called() @mark_sync_test @pytest.mark.parametrize(("auth1", "auth2"), itertools.product(SAMPLE_AUTHS, repeat=2)) -@pytest.mark.parametrize("expires_in", (None, -1, 1., 1, 1000.)) +@pytest.mark.parametrize("expires_at", (None, -1, 1., 1, 1000.)) def test_expiration_based_manager_time_expiry( auth1: t.Union[t.Tuple[str, str], Auth, None], auth2: t.Union[t.Tuple[str, str], Auth, None], - expires_in: t.Union[float, int, None], + expires_at: t.Optional[float], mocker ) -> None: - with freeze_time() as frozen_time: + with freeze_time("1970-01-01 00:00:00") as frozen_time: assert isinstance(frozen_time, FrozenDateTimeFactory) - if expires_in is None or expires_in >= 0: - temporal_auth = expiring_auth(auth1, expires_in) + if expires_at is None or expires_at >= 0: + temporal_auth = expiring_auth(auth1, expires_at) else: temporal_auth = expiring_auth(auth1) provider = mocker.MagicMock(return_value=temporal_auth) @@ -140,12 +139,12 @@ def test_expiration_based_manager_time_expiry( provider.return_value = expiring_auth(auth2) - if expires_in is None or expires_in < 0: + if expires_at is None or expires_at < 0: frozen_time.tick(1_000_000) assert manager.get_auth() is auth1 provider.assert_not_called() else: - frozen_time.tick(expires_in - 0.000001) + frozen_time.tick(expires_at - 0.000001) assert manager.get_auth() is auth1 provider.assert_not_called() frozen_time.tick(0.000002)