Skip to content

Commit 3f5074c

Browse files
authored
Change ExpiringAuth.expires_in to expires_at (#928)
Updating the preview feature to be more in line with other drivers and support more SSO flows, of which some specify token lifetime with an absolute time value while other flows use relative time. Now, there's both options: * absolute: `ExpiringAuth(some_auth, expires_at=...)` * relative: `ExpiringAuth(some_auth).expires_in(...)` (or doing it manually `ExpiringAuth(some_auth, expires_at=time.time() + expires_in)`) Add more type safety to expiring auth provider (`None`-check) Adjust unit tests and TestKit backend
1 parent a2e2699 commit 3f5074c

File tree

8 files changed

+135
-121
lines changed

8 files changed

+135
-121
lines changed

CHANGELOG.md

+3-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
See also https://github.com/neo4j/neo4j-python-driver/wiki for a full changelog.
44

55
## NEXT RELEASE
6-
- No breaking or major changes.
6+
- `neo4j.auth_management.ExpiringAuth`'s `expires_in` (in preview) was replaced
7+
by `expires_at`, which is a unix timestamp.
8+
You can use `ExpiringAuth(some_auth).expires_in(123)` instead.
79

810

911
## Version 5.8

src/neo4j/_async/auth_management.py

+14-24
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from .._async_compat.concurrency import AsyncLock
2929
from .._auth_management import (
3030
AsyncAuthManager,
31+
expiring_auth_has_expired,
3132
ExpiringAuth,
3233
)
3334
from .._meta import preview
@@ -54,24 +55,8 @@ async def on_auth_expired(self, auth: _TAuth) -> None:
5455
pass
5556

5657

57-
class _ExpiringAuthHolder:
58-
def __init__(self, auth: ExpiringAuth) -> None:
59-
self._auth = auth
60-
self._expiry = None
61-
if auth.expires_in is not None:
62-
self._expiry = time.monotonic() + auth.expires_in
63-
64-
@property
65-
def auth(self) -> _TAuth:
66-
return self._auth.auth
67-
68-
def expired(self) -> bool:
69-
if self._expiry is None:
70-
return False
71-
return time.monotonic() > self._expiry
72-
7358
class AsyncExpirationBasedAuthManager(AsyncAuthManager):
74-
_current_auth: t.Optional[_ExpiringAuthHolder]
59+
_current_auth: t.Optional[ExpiringAuth]
7560
_provider: t.Callable[[], t.Awaitable[ExpiringAuth]]
7661
_lock: AsyncLock
7762

@@ -85,17 +70,22 @@ def __init__(
8570
self._lock = AsyncLock()
8671

8772
async def _refresh_auth(self):
88-
self._current_auth = _ExpiringAuthHolder(await self._provider())
73+
self._current_auth = await self._provider()
74+
if self._current_auth is None:
75+
raise TypeError(
76+
"Auth provider function passed to expiration_based "
77+
"AuthManager returned None, expected ExpiringAuth"
78+
)
8979

9080
async def get_auth(self) -> _TAuth:
9181
async with self._lock:
9282
auth = self._current_auth
93-
if auth is not None and not auth.expired():
94-
return auth.auth
95-
log.debug("[ ] _: <TEMPORAL AUTH> refreshing (time out)")
96-
await self._refresh_auth()
97-
assert self._current_auth is not None
98-
return self._current_auth.auth
83+
if auth is None or expiring_auth_has_expired(auth):
84+
log.debug("[ ] _: <TEMPORAL AUTH> refreshing (time out)")
85+
await self._refresh_auth()
86+
auth = self._current_auth
87+
assert auth is not None
88+
return auth.auth
9989

10090
async def on_auth_expired(self, auth: _TAuth) -> None:
10191
async with self._lock:

src/neo4j/_auth_management.py

+42-6
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,15 @@
2222

2323

2424
import abc
25+
import time
2526
import typing as t
27+
import warnings
2628
from dataclasses import dataclass
2729

28-
from ._meta import preview
30+
from ._meta import (
31+
preview,
32+
PreviewWarning,
33+
)
2934
from .api import _TAuth
3035

3136

@@ -38,9 +43,11 @@ class ExpiringAuth:
3843
:meth:`.AsyncAuthManagers.expiration_based`.
3944
4045
:param auth: The authentication information.
41-
:param expires_in: The number of seconds until the authentication
42-
information expires. If :data:`None`, the authentication information
43-
is considered to not expire until the server explicitly indicates so.
46+
:param expires_at:
47+
Unix timestamp (seconds since 1970-01-01 00:00:00 UTC)
48+
indicating when the authentication information expires.
49+
If :data:`None`, the authentication information is considered to not
50+
expire until the server explicitly indicates so.
4451
4552
**This is a preview** (see :ref:`filter-warnings-ref`).
4653
It might be changed without following the deprecation policy.
@@ -52,8 +59,37 @@ class ExpiringAuth:
5259
5360
.. versionadded:: 5.8
5461
"""
55-
auth: _TAuth
56-
expires_in: t.Optional[float] = None
62+
auth: "_TAuth"
63+
expires_at: t.Optional[float] = None
64+
65+
def expires_in(self, seconds: float) -> "ExpiringAuth":
66+
"""Return a copy of this object with a new expiration time.
67+
68+
This is a convenience method for creating an :class:`.ExpiringAuth`
69+
for a relative expiration time ("expires in" instead of "expires at").
70+
71+
>>> import time, freezegun
72+
>>> with freezegun.freeze_time("1970-01-01 00:00:40"):
73+
... ExpiringAuth(("user", "pass")).expires_in(2)
74+
ExpiringAuth(auth=('user', 'pass'), expires_at=42.0)
75+
>>> with freezegun.freeze_time("1970-01-01 00:00:40"):
76+
... ExpiringAuth(("user", "pass"), time.time() + 2)
77+
ExpiringAuth(auth=('user', 'pass'), expires_at=42.0)
78+
79+
:param seconds:
80+
The number of seconds from now until the authentication information
81+
expires.
82+
"""
83+
with warnings.catch_warnings():
84+
warnings.filterwarnings("ignore",
85+
message=r"^Auth managers\b.*",
86+
category=PreviewWarning)
87+
return ExpiringAuth(self.auth, time.time() + seconds)
88+
89+
90+
def expiring_auth_has_expired(auth: ExpiringAuth) -> bool:
91+
expires_at = auth.expires_at
92+
return expires_at is not None and expires_at < time.time()
5793

5894

5995
class AuthManager(metaclass=abc.ABCMeta):

src/neo4j/_sync/auth_management.py

+14-24
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from .._async_compat.concurrency import Lock
2929
from .._auth_management import (
3030
AuthManager,
31+
expiring_auth_has_expired,
3132
ExpiringAuth,
3233
)
3334
from .._meta import preview
@@ -54,24 +55,8 @@ def on_auth_expired(self, auth: _TAuth) -> None:
5455
pass
5556

5657

57-
class _ExpiringAuthHolder:
58-
def __init__(self, auth: ExpiringAuth) -> None:
59-
self._auth = auth
60-
self._expiry = None
61-
if auth.expires_in is not None:
62-
self._expiry = time.monotonic() + auth.expires_in
63-
64-
@property
65-
def auth(self) -> _TAuth:
66-
return self._auth.auth
67-
68-
def expired(self) -> bool:
69-
if self._expiry is None:
70-
return False
71-
return time.monotonic() > self._expiry
72-
7358
class ExpirationBasedAuthManager(AuthManager):
74-
_current_auth: t.Optional[_ExpiringAuthHolder]
59+
_current_auth: t.Optional[ExpiringAuth]
7560
_provider: t.Callable[[], t.Union[ExpiringAuth]]
7661
_lock: Lock
7762

@@ -85,17 +70,22 @@ def __init__(
8570
self._lock = Lock()
8671

8772
def _refresh_auth(self):
88-
self._current_auth = _ExpiringAuthHolder(self._provider())
73+
self._current_auth = self._provider()
74+
if self._current_auth is None:
75+
raise TypeError(
76+
"Auth provider function passed to expiration_based "
77+
"AuthManager returned None, expected ExpiringAuth"
78+
)
8979

9080
def get_auth(self) -> _TAuth:
9181
with self._lock:
9282
auth = self._current_auth
93-
if auth is not None and not auth.expired():
94-
return auth.auth
95-
log.debug("[ ] _: <TEMPORAL AUTH> refreshing (time out)")
96-
self._refresh_auth()
97-
assert self._current_auth is not None
98-
return self._current_auth.auth
83+
if auth is None or expiring_auth_has_expired(auth):
84+
log.debug("[ ] _: <TEMPORAL AUTH> refreshing (time out)")
85+
self._refresh_auth()
86+
auth = self._current_auth
87+
assert auth is not None
88+
return auth.auth
9989

10090
def on_auth_expired(self, auth: _TAuth) -> None:
10191
with self._lock:

testkitbackend/_async/requests.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -271,13 +271,12 @@ async def ExpirationBasedAuthTokenProviderCompleted(backend, data):
271271
"AuthTokenAndExpiration")
272272
temp_auth_data = temp_auth_data["data"]
273273
auth_token = fromtestkit.to_auth_token(temp_auth_data, "auth")
274-
if temp_auth_data["expiresInMs"] is not None:
275-
expires_in = temp_auth_data["expiresInMs"] / 1000
276-
else:
277-
expires_in = None
278274
with warning_check(neo4j.PreviewWarning,
279275
"Auth managers are a preview feature."):
280-
expiring_auth = ExpiringAuth(auth_token, expires_in)
276+
expiring_auth = ExpiringAuth(auth_token)
277+
if temp_auth_data["expiresInMs"] is not None:
278+
expires_in = temp_auth_data["expiresInMs"] / 1000
279+
expiring_auth = expiring_auth.expires_in(expires_in)
281280

282281
backend.expiring_auth_token_supplies[data["requestId"]] = expiring_auth
283282

testkitbackend/_sync/requests.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -271,13 +271,12 @@ def ExpirationBasedAuthTokenProviderCompleted(backend, data):
271271
"AuthTokenAndExpiration")
272272
temp_auth_data = temp_auth_data["data"]
273273
auth_token = fromtestkit.to_auth_token(temp_auth_data, "auth")
274-
if temp_auth_data["expiresInMs"] is not None:
275-
expires_in = temp_auth_data["expiresInMs"] / 1000
276-
else:
277-
expires_in = None
278274
with warning_check(neo4j.PreviewWarning,
279275
"Auth managers are a preview feature."):
280-
expiring_auth = ExpiringAuth(auth_token, expires_in)
276+
expiring_auth = ExpiringAuth(auth_token)
277+
if temp_auth_data["expiresInMs"] is not None:
278+
expires_in = temp_auth_data["expiresInMs"] / 1000
279+
expiring_auth = expiring_auth.expires_in(expires_in)
281280

282281
backend.expiring_auth_token_supplies[data["requestId"]] = expiring_auth
283282

tests/unit/async_/test_auth_manager.py

+27-28
Original file line numberDiff line numberDiff line change
@@ -82,52 +82,51 @@ async def test_static_manager(
8282
@mark_async_test
8383
@pytest.mark.parametrize(("auth1", "auth2"),
8484
itertools.product(SAMPLE_AUTHS, repeat=2))
85-
@pytest.mark.parametrize("expires_in", (None, -1, 1., 1, 1000.))
85+
@pytest.mark.parametrize("expires_at", (None, .001, 1, 1000.))
8686
async def test_expiration_based_manager_manual_expiry(
8787
auth1: t.Union[t.Tuple[str, str], Auth, None],
8888
auth2: t.Union[t.Tuple[str, str], Auth, None],
89-
expires_in: t.Union[float, int],
89+
expires_at: t.Optional[float],
9090
mocker
9191
) -> None:
92-
if expires_in is None or expires_in >= 0:
93-
temporal_auth = expiring_auth(auth1, expires_in)
94-
else:
95-
temporal_auth = expiring_auth(auth1)
96-
provider = mocker.AsyncMock(return_value=temporal_auth)
97-
manager: AsyncAuthManager = expiration_based_auth_manager(provider)
92+
with freeze_time("1970-01-01 00:00:00") as frozen_time:
93+
assert isinstance(frozen_time, FrozenDateTimeFactory)
94+
temporal_auth = expiring_auth(auth1, expires_at)
95+
provider = mocker.AsyncMock(return_value=temporal_auth)
96+
manager: AsyncAuthManager = expiration_based_auth_manager(provider)
9897

99-
provider.assert_not_called()
100-
assert await manager.get_auth() is auth1
101-
provider.assert_awaited_once()
102-
provider.reset_mock()
98+
provider.assert_not_called()
99+
assert await manager.get_auth() is auth1
100+
provider.assert_awaited_once()
101+
provider.reset_mock()
103102

104-
provider.return_value = expiring_auth(auth2)
103+
provider.return_value = expiring_auth(auth2)
105104

106-
await manager.on_auth_expired(("something", "else"))
107-
assert await manager.get_auth() is auth1
108-
provider.assert_not_called()
105+
await manager.on_auth_expired(("something", "else"))
106+
assert await manager.get_auth() is auth1
107+
provider.assert_not_called()
109108

110-
await manager.on_auth_expired(auth1)
111-
provider.assert_awaited_once()
112-
provider.reset_mock()
113-
assert await manager.get_auth() is auth2
114-
provider.assert_not_called()
109+
await manager.on_auth_expired(auth1)
110+
provider.assert_awaited_once()
111+
provider.reset_mock()
112+
assert await manager.get_auth() is auth2
113+
provider.assert_not_called()
115114

116115

117116
@mark_async_test
118117
@pytest.mark.parametrize(("auth1", "auth2"),
119118
itertools.product(SAMPLE_AUTHS, repeat=2))
120-
@pytest.mark.parametrize("expires_in", (None, -1, 1., 1, 1000.))
119+
@pytest.mark.parametrize("expires_at", (None, -1, 1., 1, 1000.))
121120
async def test_expiration_based_manager_time_expiry(
122121
auth1: t.Union[t.Tuple[str, str], Auth, None],
123122
auth2: t.Union[t.Tuple[str, str], Auth, None],
124-
expires_in: t.Union[float, int, None],
123+
expires_at: t.Optional[float],
125124
mocker
126125
) -> None:
127-
with freeze_time() as frozen_time:
126+
with freeze_time("1970-01-01 00:00:00") as frozen_time:
128127
assert isinstance(frozen_time, FrozenDateTimeFactory)
129-
if expires_in is None or expires_in >= 0:
130-
temporal_auth = expiring_auth(auth1, expires_in)
128+
if expires_at is None or expires_at >= 0:
129+
temporal_auth = expiring_auth(auth1, expires_at)
131130
else:
132131
temporal_auth = expiring_auth(auth1)
133132
provider = mocker.AsyncMock(return_value=temporal_auth)
@@ -140,12 +139,12 @@ async def test_expiration_based_manager_time_expiry(
140139

141140
provider.return_value = expiring_auth(auth2)
142141

143-
if expires_in is None or expires_in < 0:
142+
if expires_at is None or expires_at < 0:
144143
frozen_time.tick(1_000_000)
145144
assert await manager.get_auth() is auth1
146145
provider.assert_not_called()
147146
else:
148-
frozen_time.tick(expires_in - 0.000001)
147+
frozen_time.tick(expires_at - 0.000001)
149148
assert await manager.get_auth() is auth1
150149
provider.assert_not_called()
151150
frozen_time.tick(0.000002)

0 commit comments

Comments
 (0)