From 5a50e788c0657836776f85637963c5c868e81ec9 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Tue, 16 May 2023 18:20:34 +0200 Subject: [PATCH 1/3] Change ExpiringAuth.expires_in to expires_at Updating the preview feature to be more in line with other drivers and support more SSO flows, of which some specify token lifetime with an absolute time value while other flows use relative time. Now, there's both options: * absolute: `ExpiringAuth(some_auth, expires_at=...)` * relative: `ExpiringAuth(some_auth).expires_in(...)` (or doing it manually `ExpiringAuth(some_auth, expires_at=time.time() + expires_in)`) --- CHANGELOG.md | 4 ++- src/neo4j/_async/auth_management.py | 31 ++++++----------------- src/neo4j/_auth_management.py | 38 +++++++++++++++++++++++++---- src/neo4j/_sync/auth_management.py | 31 ++++++----------------- 4 files changed, 52 insertions(+), 52 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 31d8fac7f..b5931b531 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,7 +3,9 @@ See also https://github.com/neo4j/neo4j-python-driver/wiki for a full changelog. ## NEXT RELEASE -- No breaking or major changes. +- `neo4j.auth_management.ExpiringAuth`'s `expires_in` (in preview) was replaced + by `expires_at`, which is a unix timestamp. + You can use `ExpiringAuth(some_auth).expires_in(123)` instead. ## Version 5.8 diff --git a/src/neo4j/_async/auth_management.py b/src/neo4j/_async/auth_management.py index 6fe02d0a1..24e5166f1 100644 --- a/src/neo4j/_async/auth_management.py +++ b/src/neo4j/_async/auth_management.py @@ -28,6 +28,7 @@ from .._async_compat.concurrency import AsyncLock from .._auth_management import ( AsyncAuthManager, + expiring_auth_has_expired, ExpiringAuth, ) from .._meta import preview @@ -54,24 +55,8 @@ async def on_auth_expired(self, auth: _TAuth) -> None: pass -class _ExpiringAuthHolder: - def __init__(self, auth: ExpiringAuth) -> None: - self._auth = auth - self._expiry = None - if auth.expires_in is not None: - self._expiry = time.monotonic() + auth.expires_in - - @property - def auth(self) -> _TAuth: - return self._auth.auth - - def expired(self) -> bool: - if self._expiry is None: - return False - return time.monotonic() > self._expiry - class AsyncExpirationBasedAuthManager(AsyncAuthManager): - _current_auth: t.Optional[_ExpiringAuthHolder] + _current_auth: t.Optional[ExpiringAuth] _provider: t.Callable[[], t.Awaitable[ExpiringAuth]] _lock: AsyncLock @@ -85,16 +70,16 @@ def __init__( self._lock = AsyncLock() async def _refresh_auth(self): - self._current_auth = _ExpiringAuthHolder(await self._provider()) + self._current_auth = await self._provider() async def get_auth(self) -> _TAuth: async with self._lock: auth = self._current_auth - if auth is not None and not auth.expired(): - return auth.auth - log.debug("[ ] _: refreshing (time out)") - await self._refresh_auth() - assert self._current_auth is not None + if auth is None or expiring_auth_has_expired(auth): + log.debug("[ ] _: refreshing (time out)") + await self._refresh_auth() + auth = self._current_auth + assert auth is not None return self._current_auth.auth async def on_auth_expired(self, auth: _TAuth) -> None: diff --git a/src/neo4j/_auth_management.py b/src/neo4j/_auth_management.py index 69bfcec8c..9fc912f12 100644 --- a/src/neo4j/_auth_management.py +++ b/src/neo4j/_auth_management.py @@ -22,6 +22,7 @@ import abc +import time import typing as t from dataclasses import dataclass @@ -38,9 +39,11 @@ class ExpiringAuth: :meth:`.AsyncAuthManagers.temporal`. :param auth: The authentication information. - :param expires_in: The number of seconds until the authentication - information expires. If :data:`None`, the authentication information - is considered to not expire until the server explicitly indicates so. + :param expires_at: + Unix timestamp (seconds since 1970-01-01 00:00:00 UTC) + indicating when the authentication information expires. + If :data:`None`, the authentication information is considered to not + expire until the server explicitly indicates so. **This is a preview** (see :ref:`filter-warnings-ref`). It might be changed without following the deprecation policy. @@ -51,8 +54,33 @@ class ExpiringAuth: .. versionadded:: 5.8 """ - auth: _TAuth - expires_in: t.Optional[float] = None + auth: "_TAuth" + expires_at: t.Optional[float] = None + + def expires_in(self, seconds: float) -> "ExpiringAuth": + """Return a copy of this object with a new expiration time. + + This is a convenience method for creating an :class:`.ExpiringAuth` + for a relative expiration time ("expires in" instead of "expires at"). + + >>> import time, freezegun + >>> with freezegun.freeze_time("1970-01-01 00:00:00"): + ... ExpiringAuth(("user", "pass")).expires_in(60) + ExpiringAuth(auth=('user', 'pass'), expires_at=60.0) + >>> with freezegun.freeze_time("1970-01-01 00:00:00"): + ... ExpiringAuth(("user", "pass"), time.time() + 60) + ExpiringAuth(auth=('user', 'pass'), expires_at=60.0) + + :param seconds: + The number of seconds from now until the authentication information + expires. + """ + return ExpiringAuth(self.auth, time.time() + seconds) + + +def expiring_auth_has_expired(auth: ExpiringAuth) -> bool: + expires_at = auth.expires_at + return expires_at is not None and expires_at < time.time() class AuthManager(metaclass=abc.ABCMeta): diff --git a/src/neo4j/_sync/auth_management.py b/src/neo4j/_sync/auth_management.py index 4513860a7..7a4ced283 100644 --- a/src/neo4j/_sync/auth_management.py +++ b/src/neo4j/_sync/auth_management.py @@ -28,6 +28,7 @@ from .._async_compat.concurrency import Lock from .._auth_management import ( AuthManager, + expiring_auth_has_expired, ExpiringAuth, ) from .._meta import preview @@ -54,24 +55,8 @@ def on_auth_expired(self, auth: _TAuth) -> None: pass -class _ExpiringAuthHolder: - def __init__(self, auth: ExpiringAuth) -> None: - self._auth = auth - self._expiry = None - if auth.expires_in is not None: - self._expiry = time.monotonic() + auth.expires_in - - @property - def auth(self) -> _TAuth: - return self._auth.auth - - def expired(self) -> bool: - if self._expiry is None: - return False - return time.monotonic() > self._expiry - class ExpirationBasedAuthManager(AuthManager): - _current_auth: t.Optional[_ExpiringAuthHolder] + _current_auth: t.Optional[ExpiringAuth] _provider: t.Callable[[], t.Union[ExpiringAuth]] _lock: Lock @@ -85,16 +70,16 @@ def __init__( self._lock = Lock() def _refresh_auth(self): - self._current_auth = _ExpiringAuthHolder(self._provider()) + self._current_auth = self._provider() def get_auth(self) -> _TAuth: with self._lock: auth = self._current_auth - if auth is not None and not auth.expired(): - return auth.auth - log.debug("[ ] _: refreshing (time out)") - self._refresh_auth() - assert self._current_auth is not None + if auth is None or expiring_auth_has_expired(auth): + log.debug("[ ] _: refreshing (time out)") + self._refresh_auth() + auth = self._current_auth + assert auth is not None return self._current_auth.auth def on_auth_expired(self, auth: _TAuth) -> None: From 85744a963acae70744322a4dd64f498ef31d0bbd Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Wed, 17 May 2023 10:47:19 +0200 Subject: [PATCH 2/3] Add more type safety to expiring auth provider --- src/neo4j/_async/auth_management.py | 7 ++++++- src/neo4j/_sync/auth_management.py | 7 ++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/neo4j/_async/auth_management.py b/src/neo4j/_async/auth_management.py index 24e5166f1..e423ddac9 100644 --- a/src/neo4j/_async/auth_management.py +++ b/src/neo4j/_async/auth_management.py @@ -71,6 +71,11 @@ def __init__( async def _refresh_auth(self): self._current_auth = await self._provider() + if self._current_auth is None: + raise TypeError( + "Auth provider function passed to expiration_based " + "AuthManager returned None, expected ExpiringAuth" + ) async def get_auth(self) -> _TAuth: async with self._lock: @@ -80,7 +85,7 @@ async def get_auth(self) -> _TAuth: await self._refresh_auth() auth = self._current_auth assert auth is not None - return self._current_auth.auth + return auth.auth async def on_auth_expired(self, auth: _TAuth) -> None: async with self._lock: diff --git a/src/neo4j/_sync/auth_management.py b/src/neo4j/_sync/auth_management.py index 7a4ced283..f54c70780 100644 --- a/src/neo4j/_sync/auth_management.py +++ b/src/neo4j/_sync/auth_management.py @@ -71,6 +71,11 @@ def __init__( def _refresh_auth(self): self._current_auth = self._provider() + if self._current_auth is None: + raise TypeError( + "Auth provider function passed to expiration_based " + "AuthManager returned None, expected ExpiringAuth" + ) def get_auth(self) -> _TAuth: with self._lock: @@ -80,7 +85,7 @@ def get_auth(self) -> _TAuth: self._refresh_auth() auth = self._current_auth assert auth is not None - return self._current_auth.auth + return auth.auth def on_auth_expired(self, auth: _TAuth) -> None: with self._lock: From 6ebf2b92b25dff189d63fb532ce4dfbd31b59a38 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Wed, 17 May 2023 14:23:25 +0200 Subject: [PATCH 3/3] Adjust unit tests and TestKit backend --- src/neo4j/_auth_management.py | 24 +++++++---- testkitbackend/_async/requests.py | 9 ++--- testkitbackend/_sync/requests.py | 9 ++--- tests/unit/async_/test_auth_manager.py | 55 +++++++++++++------------- tests/unit/sync/test_auth_manager.py | 55 +++++++++++++------------- 5 files changed, 78 insertions(+), 74 deletions(-) diff --git a/src/neo4j/_auth_management.py b/src/neo4j/_auth_management.py index 58b2ae44d..c4f7b259f 100644 --- a/src/neo4j/_auth_management.py +++ b/src/neo4j/_auth_management.py @@ -24,9 +24,13 @@ import abc import time import typing as t +import warnings from dataclasses import dataclass -from ._meta import preview +from ._meta import ( + preview, + PreviewWarning, +) from .api import _TAuth @@ -65,18 +69,22 @@ def expires_in(self, seconds: float) -> "ExpiringAuth": for a relative expiration time ("expires in" instead of "expires at"). >>> import time, freezegun - >>> with freezegun.freeze_time("1970-01-01 00:00:00"): - ... ExpiringAuth(("user", "pass")).expires_in(60) - ExpiringAuth(auth=('user', 'pass'), expires_at=60.0) - >>> with freezegun.freeze_time("1970-01-01 00:00:00"): - ... ExpiringAuth(("user", "pass"), time.time() + 60) - ExpiringAuth(auth=('user', 'pass'), expires_at=60.0) + >>> with freezegun.freeze_time("1970-01-01 00:00:40"): + ... ExpiringAuth(("user", "pass")).expires_in(2) + ExpiringAuth(auth=('user', 'pass'), expires_at=42.0) + >>> with freezegun.freeze_time("1970-01-01 00:00:40"): + ... ExpiringAuth(("user", "pass"), time.time() + 2) + ExpiringAuth(auth=('user', 'pass'), expires_at=42.0) :param seconds: The number of seconds from now until the authentication information expires. """ - return ExpiringAuth(self.auth, time.time() + seconds) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", + message=r"^Auth managers\b.*", + category=PreviewWarning) + return ExpiringAuth(self.auth, time.time() + seconds) def expiring_auth_has_expired(auth: ExpiringAuth) -> bool: diff --git a/testkitbackend/_async/requests.py b/testkitbackend/_async/requests.py index 6feffca6d..dfe891a1e 100644 --- a/testkitbackend/_async/requests.py +++ b/testkitbackend/_async/requests.py @@ -271,13 +271,12 @@ async def ExpirationBasedAuthTokenProviderCompleted(backend, data): "AuthTokenAndExpiration") temp_auth_data = temp_auth_data["data"] auth_token = fromtestkit.to_auth_token(temp_auth_data, "auth") - if temp_auth_data["expiresInMs"] is not None: - expires_in = temp_auth_data["expiresInMs"] / 1000 - else: - expires_in = None with warning_check(neo4j.PreviewWarning, "Auth managers are a preview feature."): - expiring_auth = ExpiringAuth(auth_token, expires_in) + expiring_auth = ExpiringAuth(auth_token) + if temp_auth_data["expiresInMs"] is not None: + expires_in = temp_auth_data["expiresInMs"] / 1000 + expiring_auth = expiring_auth.expires_in(expires_in) backend.expiring_auth_token_supplies[data["requestId"]] = expiring_auth diff --git a/testkitbackend/_sync/requests.py b/testkitbackend/_sync/requests.py index 3a23e3285..bd96f751d 100644 --- a/testkitbackend/_sync/requests.py +++ b/testkitbackend/_sync/requests.py @@ -271,13 +271,12 @@ def ExpirationBasedAuthTokenProviderCompleted(backend, data): "AuthTokenAndExpiration") temp_auth_data = temp_auth_data["data"] auth_token = fromtestkit.to_auth_token(temp_auth_data, "auth") - if temp_auth_data["expiresInMs"] is not None: - expires_in = temp_auth_data["expiresInMs"] / 1000 - else: - expires_in = None with warning_check(neo4j.PreviewWarning, "Auth managers are a preview feature."): - expiring_auth = ExpiringAuth(auth_token, expires_in) + expiring_auth = ExpiringAuth(auth_token) + if temp_auth_data["expiresInMs"] is not None: + expires_in = temp_auth_data["expiresInMs"] / 1000 + expiring_auth = expiring_auth.expires_in(expires_in) backend.expiring_auth_token_supplies[data["requestId"]] = expiring_auth diff --git a/tests/unit/async_/test_auth_manager.py b/tests/unit/async_/test_auth_manager.py index e29abacf1..65704204f 100644 --- a/tests/unit/async_/test_auth_manager.py +++ b/tests/unit/async_/test_auth_manager.py @@ -82,52 +82,51 @@ async def test_static_manager( @mark_async_test @pytest.mark.parametrize(("auth1", "auth2"), itertools.product(SAMPLE_AUTHS, repeat=2)) -@pytest.mark.parametrize("expires_in", (None, -1, 1., 1, 1000.)) +@pytest.mark.parametrize("expires_at", (None, .001, 1, 1000.)) async def test_expiration_based_manager_manual_expiry( auth1: t.Union[t.Tuple[str, str], Auth, None], auth2: t.Union[t.Tuple[str, str], Auth, None], - expires_in: t.Union[float, int], + expires_at: t.Optional[float], mocker ) -> None: - if expires_in is None or expires_in >= 0: - temporal_auth = expiring_auth(auth1, expires_in) - else: - temporal_auth = expiring_auth(auth1) - provider = mocker.AsyncMock(return_value=temporal_auth) - manager: AsyncAuthManager = expiration_based_auth_manager(provider) + with freeze_time("1970-01-01 00:00:00") as frozen_time: + assert isinstance(frozen_time, FrozenDateTimeFactory) + temporal_auth = expiring_auth(auth1, expires_at) + provider = mocker.AsyncMock(return_value=temporal_auth) + manager: AsyncAuthManager = expiration_based_auth_manager(provider) - provider.assert_not_called() - assert await manager.get_auth() is auth1 - provider.assert_awaited_once() - provider.reset_mock() + provider.assert_not_called() + assert await manager.get_auth() is auth1 + provider.assert_awaited_once() + provider.reset_mock() - provider.return_value = expiring_auth(auth2) + provider.return_value = expiring_auth(auth2) - await manager.on_auth_expired(("something", "else")) - assert await manager.get_auth() is auth1 - provider.assert_not_called() + await manager.on_auth_expired(("something", "else")) + assert await manager.get_auth() is auth1 + provider.assert_not_called() - await manager.on_auth_expired(auth1) - provider.assert_awaited_once() - provider.reset_mock() - assert await manager.get_auth() is auth2 - provider.assert_not_called() + await manager.on_auth_expired(auth1) + provider.assert_awaited_once() + provider.reset_mock() + assert await manager.get_auth() is auth2 + provider.assert_not_called() @mark_async_test @pytest.mark.parametrize(("auth1", "auth2"), itertools.product(SAMPLE_AUTHS, repeat=2)) -@pytest.mark.parametrize("expires_in", (None, -1, 1., 1, 1000.)) +@pytest.mark.parametrize("expires_at", (None, -1, 1., 1, 1000.)) async def test_expiration_based_manager_time_expiry( auth1: t.Union[t.Tuple[str, str], Auth, None], auth2: t.Union[t.Tuple[str, str], Auth, None], - expires_in: t.Union[float, int, None], + expires_at: t.Optional[float], mocker ) -> None: - with freeze_time() as frozen_time: + with freeze_time("1970-01-01 00:00:00") as frozen_time: assert isinstance(frozen_time, FrozenDateTimeFactory) - if expires_in is None or expires_in >= 0: - temporal_auth = expiring_auth(auth1, expires_in) + if expires_at is None or expires_at >= 0: + temporal_auth = expiring_auth(auth1, expires_at) else: temporal_auth = expiring_auth(auth1) provider = mocker.AsyncMock(return_value=temporal_auth) @@ -140,12 +139,12 @@ async def test_expiration_based_manager_time_expiry( provider.return_value = expiring_auth(auth2) - if expires_in is None or expires_in < 0: + if expires_at is None or expires_at < 0: frozen_time.tick(1_000_000) assert await manager.get_auth() is auth1 provider.assert_not_called() else: - frozen_time.tick(expires_in - 0.000001) + frozen_time.tick(expires_at - 0.000001) assert await manager.get_auth() is auth1 provider.assert_not_called() frozen_time.tick(0.000002) diff --git a/tests/unit/sync/test_auth_manager.py b/tests/unit/sync/test_auth_manager.py index 15d8b5982..634fc42a6 100644 --- a/tests/unit/sync/test_auth_manager.py +++ b/tests/unit/sync/test_auth_manager.py @@ -82,52 +82,51 @@ def test_static_manager( @mark_sync_test @pytest.mark.parametrize(("auth1", "auth2"), itertools.product(SAMPLE_AUTHS, repeat=2)) -@pytest.mark.parametrize("expires_in", (None, -1, 1., 1, 1000.)) +@pytest.mark.parametrize("expires_at", (None, .001, 1, 1000.)) def test_expiration_based_manager_manual_expiry( auth1: t.Union[t.Tuple[str, str], Auth, None], auth2: t.Union[t.Tuple[str, str], Auth, None], - expires_in: t.Union[float, int], + expires_at: t.Optional[float], mocker ) -> None: - if expires_in is None or expires_in >= 0: - temporal_auth = expiring_auth(auth1, expires_in) - else: - temporal_auth = expiring_auth(auth1) - provider = mocker.MagicMock(return_value=temporal_auth) - manager: AuthManager = expiration_based_auth_manager(provider) + with freeze_time("1970-01-01 00:00:00") as frozen_time: + assert isinstance(frozen_time, FrozenDateTimeFactory) + temporal_auth = expiring_auth(auth1, expires_at) + provider = mocker.MagicMock(return_value=temporal_auth) + manager: AuthManager = expiration_based_auth_manager(provider) - provider.assert_not_called() - assert manager.get_auth() is auth1 - provider.assert_called_once() - provider.reset_mock() + provider.assert_not_called() + assert manager.get_auth() is auth1 + provider.assert_called_once() + provider.reset_mock() - provider.return_value = expiring_auth(auth2) + provider.return_value = expiring_auth(auth2) - manager.on_auth_expired(("something", "else")) - assert manager.get_auth() is auth1 - provider.assert_not_called() + manager.on_auth_expired(("something", "else")) + assert manager.get_auth() is auth1 + provider.assert_not_called() - manager.on_auth_expired(auth1) - provider.assert_called_once() - provider.reset_mock() - assert manager.get_auth() is auth2 - provider.assert_not_called() + manager.on_auth_expired(auth1) + provider.assert_called_once() + provider.reset_mock() + assert manager.get_auth() is auth2 + provider.assert_not_called() @mark_sync_test @pytest.mark.parametrize(("auth1", "auth2"), itertools.product(SAMPLE_AUTHS, repeat=2)) -@pytest.mark.parametrize("expires_in", (None, -1, 1., 1, 1000.)) +@pytest.mark.parametrize("expires_at", (None, -1, 1., 1, 1000.)) def test_expiration_based_manager_time_expiry( auth1: t.Union[t.Tuple[str, str], Auth, None], auth2: t.Union[t.Tuple[str, str], Auth, None], - expires_in: t.Union[float, int, None], + expires_at: t.Optional[float], mocker ) -> None: - with freeze_time() as frozen_time: + with freeze_time("1970-01-01 00:00:00") as frozen_time: assert isinstance(frozen_time, FrozenDateTimeFactory) - if expires_in is None or expires_in >= 0: - temporal_auth = expiring_auth(auth1, expires_in) + if expires_at is None or expires_at >= 0: + temporal_auth = expiring_auth(auth1, expires_at) else: temporal_auth = expiring_auth(auth1) provider = mocker.MagicMock(return_value=temporal_auth) @@ -140,12 +139,12 @@ def test_expiration_based_manager_time_expiry( provider.return_value = expiring_auth(auth2) - if expires_in is None or expires_in < 0: + if expires_at is None or expires_at < 0: frozen_time.tick(1_000_000) assert manager.get_auth() is auth1 provider.assert_not_called() else: - frozen_time.tick(expires_in - 0.000001) + frozen_time.tick(expires_at - 0.000001) assert manager.get_auth() is auth1 provider.assert_not_called() frozen_time.tick(0.000002)