Skip to content

Change ExpiringAuth.expires_in to expires_at #928

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
See also https://github.com/neo4j/neo4j-python-driver/wiki for a full changelog.

## NEXT RELEASE
- No breaking or major changes.
- `neo4j.auth_management.ExpiringAuth`'s `expires_in` (in preview) was replaced
by `expires_at`, which is a unix timestamp.
You can use `ExpiringAuth(some_auth).expires_in(123)` instead.


## Version 5.8
Expand Down
38 changes: 14 additions & 24 deletions src/neo4j/_async/auth_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from .._async_compat.concurrency import AsyncLock
from .._auth_management import (
AsyncAuthManager,
expiring_auth_has_expired,
ExpiringAuth,
)
from .._meta import preview
Expand All @@ -54,24 +55,8 @@ async def on_auth_expired(self, auth: _TAuth) -> None:
pass


class _ExpiringAuthHolder:
def __init__(self, auth: ExpiringAuth) -> None:
self._auth = auth
self._expiry = None
if auth.expires_in is not None:
self._expiry = time.monotonic() + auth.expires_in

@property
def auth(self) -> _TAuth:
return self._auth.auth

def expired(self) -> bool:
if self._expiry is None:
return False
return time.monotonic() > self._expiry

class AsyncExpirationBasedAuthManager(AsyncAuthManager):
_current_auth: t.Optional[_ExpiringAuthHolder]
_current_auth: t.Optional[ExpiringAuth]
_provider: t.Callable[[], t.Awaitable[ExpiringAuth]]
_lock: AsyncLock

Expand All @@ -85,17 +70,22 @@ def __init__(
self._lock = AsyncLock()

async def _refresh_auth(self):
self._current_auth = _ExpiringAuthHolder(await self._provider())
self._current_auth = await self._provider()
if self._current_auth is None:
raise TypeError(
"Auth provider function passed to expiration_based "
"AuthManager returned None, expected ExpiringAuth"
)

async def get_auth(self) -> _TAuth:
async with self._lock:
auth = self._current_auth
if auth is not None and not auth.expired():
return auth.auth
log.debug("[ ] _: <TEMPORAL AUTH> refreshing (time out)")
await self._refresh_auth()
assert self._current_auth is not None
return self._current_auth.auth
if auth is None or expiring_auth_has_expired(auth):
log.debug("[ ] _: <TEMPORAL AUTH> refreshing (time out)")
await self._refresh_auth()
auth = self._current_auth
assert auth is not None
return auth.auth

async def on_auth_expired(self, auth: _TAuth) -> None:
async with self._lock:
Expand Down
48 changes: 42 additions & 6 deletions src/neo4j/_auth_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,15 @@


import abc
import time
import typing as t
import warnings
from dataclasses import dataclass

from ._meta import preview
from ._meta import (
preview,
PreviewWarning,
)
from .api import _TAuth


Expand All @@ -38,9 +43,11 @@ class ExpiringAuth:
:meth:`.AsyncAuthManagers.expiration_based`.

:param auth: The authentication information.
:param expires_in: The number of seconds until the authentication
information expires. If :data:`None`, the authentication information
is considered to not expire until the server explicitly indicates so.
:param expires_at:
Unix timestamp (seconds since 1970-01-01 00:00:00 UTC)
indicating when the authentication information expires.
If :data:`None`, the authentication information is considered to not
expire until the server explicitly indicates so.

**This is a preview** (see :ref:`filter-warnings-ref`).
It might be changed without following the deprecation policy.
Expand All @@ -52,8 +59,37 @@ class ExpiringAuth:

.. versionadded:: 5.8
"""
auth: _TAuth
expires_in: t.Optional[float] = None
auth: "_TAuth"
expires_at: t.Optional[float] = None

def expires_in(self, seconds: float) -> "ExpiringAuth":
"""Return a copy of this object with a new expiration time.

This is a convenience method for creating an :class:`.ExpiringAuth`
for a relative expiration time ("expires in" instead of "expires at").

>>> import time, freezegun
>>> with freezegun.freeze_time("1970-01-01 00:00:40"):
... ExpiringAuth(("user", "pass")).expires_in(2)
ExpiringAuth(auth=('user', 'pass'), expires_at=42.0)
>>> with freezegun.freeze_time("1970-01-01 00:00:40"):
... ExpiringAuth(("user", "pass"), time.time() + 2)
ExpiringAuth(auth=('user', 'pass'), expires_at=42.0)

:param seconds:
The number of seconds from now until the authentication information
expires.
"""
with warnings.catch_warnings():
warnings.filterwarnings("ignore",
message=r"^Auth managers\b.*",
category=PreviewWarning)
return ExpiringAuth(self.auth, time.time() + seconds)


def expiring_auth_has_expired(auth: ExpiringAuth) -> bool:
expires_at = auth.expires_at
return expires_at is not None and expires_at < time.time()


class AuthManager(metaclass=abc.ABCMeta):
Expand Down
38 changes: 14 additions & 24 deletions src/neo4j/_sync/auth_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from .._async_compat.concurrency import Lock
from .._auth_management import (
AuthManager,
expiring_auth_has_expired,
ExpiringAuth,
)
from .._meta import preview
Expand All @@ -54,24 +55,8 @@ def on_auth_expired(self, auth: _TAuth) -> None:
pass


class _ExpiringAuthHolder:
def __init__(self, auth: ExpiringAuth) -> None:
self._auth = auth
self._expiry = None
if auth.expires_in is not None:
self._expiry = time.monotonic() + auth.expires_in

@property
def auth(self) -> _TAuth:
return self._auth.auth

def expired(self) -> bool:
if self._expiry is None:
return False
return time.monotonic() > self._expiry

class ExpirationBasedAuthManager(AuthManager):
_current_auth: t.Optional[_ExpiringAuthHolder]
_current_auth: t.Optional[ExpiringAuth]
_provider: t.Callable[[], t.Union[ExpiringAuth]]
_lock: Lock

Expand All @@ -85,17 +70,22 @@ def __init__(
self._lock = Lock()

def _refresh_auth(self):
self._current_auth = _ExpiringAuthHolder(self._provider())
self._current_auth = self._provider()
if self._current_auth is None:
raise TypeError(
"Auth provider function passed to expiration_based "
"AuthManager returned None, expected ExpiringAuth"
)

def get_auth(self) -> _TAuth:
with self._lock:
auth = self._current_auth
if auth is not None and not auth.expired():
return auth.auth
log.debug("[ ] _: <TEMPORAL AUTH> refreshing (time out)")
self._refresh_auth()
assert self._current_auth is not None
return self._current_auth.auth
if auth is None or expiring_auth_has_expired(auth):
log.debug("[ ] _: <TEMPORAL AUTH> refreshing (time out)")
self._refresh_auth()
auth = self._current_auth
assert auth is not None
return auth.auth

def on_auth_expired(self, auth: _TAuth) -> None:
with self._lock:
Expand Down
9 changes: 4 additions & 5 deletions testkitbackend/_async/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,13 +271,12 @@ async def ExpirationBasedAuthTokenProviderCompleted(backend, data):
"AuthTokenAndExpiration")
temp_auth_data = temp_auth_data["data"]
auth_token = fromtestkit.to_auth_token(temp_auth_data, "auth")
if temp_auth_data["expiresInMs"] is not None:
expires_in = temp_auth_data["expiresInMs"] / 1000
else:
expires_in = None
with warning_check(neo4j.PreviewWarning,
"Auth managers are a preview feature."):
expiring_auth = ExpiringAuth(auth_token, expires_in)
expiring_auth = ExpiringAuth(auth_token)
if temp_auth_data["expiresInMs"] is not None:
expires_in = temp_auth_data["expiresInMs"] / 1000
expiring_auth = expiring_auth.expires_in(expires_in)

backend.expiring_auth_token_supplies[data["requestId"]] = expiring_auth

Expand Down
9 changes: 4 additions & 5 deletions testkitbackend/_sync/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,13 +271,12 @@ def ExpirationBasedAuthTokenProviderCompleted(backend, data):
"AuthTokenAndExpiration")
temp_auth_data = temp_auth_data["data"]
auth_token = fromtestkit.to_auth_token(temp_auth_data, "auth")
if temp_auth_data["expiresInMs"] is not None:
expires_in = temp_auth_data["expiresInMs"] / 1000
else:
expires_in = None
with warning_check(neo4j.PreviewWarning,
"Auth managers are a preview feature."):
expiring_auth = ExpiringAuth(auth_token, expires_in)
expiring_auth = ExpiringAuth(auth_token)
if temp_auth_data["expiresInMs"] is not None:
expires_in = temp_auth_data["expiresInMs"] / 1000
expiring_auth = expiring_auth.expires_in(expires_in)

backend.expiring_auth_token_supplies[data["requestId"]] = expiring_auth

Expand Down
55 changes: 27 additions & 28 deletions tests/unit/async_/test_auth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,52 +82,51 @@ async def test_static_manager(
@mark_async_test
@pytest.mark.parametrize(("auth1", "auth2"),
itertools.product(SAMPLE_AUTHS, repeat=2))
@pytest.mark.parametrize("expires_in", (None, -1, 1., 1, 1000.))
@pytest.mark.parametrize("expires_at", (None, .001, 1, 1000.))
async def test_expiration_based_manager_manual_expiry(
auth1: t.Union[t.Tuple[str, str], Auth, None],
auth2: t.Union[t.Tuple[str, str], Auth, None],
expires_in: t.Union[float, int],
expires_at: t.Optional[float],
mocker
) -> None:
if expires_in is None or expires_in >= 0:
temporal_auth = expiring_auth(auth1, expires_in)
else:
temporal_auth = expiring_auth(auth1)
provider = mocker.AsyncMock(return_value=temporal_auth)
manager: AsyncAuthManager = expiration_based_auth_manager(provider)
with freeze_time("1970-01-01 00:00:00") as frozen_time:
assert isinstance(frozen_time, FrozenDateTimeFactory)
temporal_auth = expiring_auth(auth1, expires_at)
provider = mocker.AsyncMock(return_value=temporal_auth)
manager: AsyncAuthManager = expiration_based_auth_manager(provider)

provider.assert_not_called()
assert await manager.get_auth() is auth1
provider.assert_awaited_once()
provider.reset_mock()
provider.assert_not_called()
assert await manager.get_auth() is auth1
provider.assert_awaited_once()
provider.reset_mock()

provider.return_value = expiring_auth(auth2)
provider.return_value = expiring_auth(auth2)

await manager.on_auth_expired(("something", "else"))
assert await manager.get_auth() is auth1
provider.assert_not_called()
await manager.on_auth_expired(("something", "else"))
assert await manager.get_auth() is auth1
provider.assert_not_called()

await manager.on_auth_expired(auth1)
provider.assert_awaited_once()
provider.reset_mock()
assert await manager.get_auth() is auth2
provider.assert_not_called()
await manager.on_auth_expired(auth1)
provider.assert_awaited_once()
provider.reset_mock()
assert await manager.get_auth() is auth2
provider.assert_not_called()


@mark_async_test
@pytest.mark.parametrize(("auth1", "auth2"),
itertools.product(SAMPLE_AUTHS, repeat=2))
@pytest.mark.parametrize("expires_in", (None, -1, 1., 1, 1000.))
@pytest.mark.parametrize("expires_at", (None, -1, 1., 1, 1000.))
async def test_expiration_based_manager_time_expiry(
auth1: t.Union[t.Tuple[str, str], Auth, None],
auth2: t.Union[t.Tuple[str, str], Auth, None],
expires_in: t.Union[float, int, None],
expires_at: t.Optional[float],
mocker
) -> None:
with freeze_time() as frozen_time:
with freeze_time("1970-01-01 00:00:00") as frozen_time:
assert isinstance(frozen_time, FrozenDateTimeFactory)
if expires_in is None or expires_in >= 0:
temporal_auth = expiring_auth(auth1, expires_in)
if expires_at is None or expires_at >= 0:
temporal_auth = expiring_auth(auth1, expires_at)
else:
temporal_auth = expiring_auth(auth1)
provider = mocker.AsyncMock(return_value=temporal_auth)
Expand All @@ -140,12 +139,12 @@ async def test_expiration_based_manager_time_expiry(

provider.return_value = expiring_auth(auth2)

if expires_in is None or expires_in < 0:
if expires_at is None or expires_at < 0:
frozen_time.tick(1_000_000)
assert await manager.get_auth() is auth1
provider.assert_not_called()
else:
frozen_time.tick(expires_in - 0.000001)
frozen_time.tick(expires_at - 0.000001)
assert await manager.get_auth() is auth1
provider.assert_not_called()
frozen_time.tick(0.000002)
Expand Down
Loading