Skip to content

JWT Connection Support #14

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions arangoasync/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"JwtToken",
]

import time
from dataclasses import dataclass

import jwt
Expand All @@ -27,24 +28,24 @@ class JwtToken:
"""JWT token.

Args:
token (str | bytes): JWT token.
token (str): JWT token.

Raises:
TypeError: If the token type is not str or bytes.
JWTExpiredError: If the token expired.
jwt.ExpiredSignatureError: If the token expired.
"""

def __init__(self, token: str | bytes) -> None:
def __init__(self, token: str) -> None:
self._token = token
self._validate()

@property
def token(self) -> str | bytes:
def token(self) -> str:
"""Get token."""
return self._token

@token.setter
def token(self, token: str | bytes) -> None:
def token(self, token: str) -> None:
"""Set token.

Raises:
Expand All @@ -53,9 +54,22 @@ def token(self, token: str | bytes) -> None:
self._token = token
self._validate()

def needs_refresh(self, leeway: int = 0) -> bool:
"""Check if the token needs to be refreshed.

Args:
leeway (int): Leeway in seconds, before official expiration,
when to consider the token expired.

Returns:
bool: True if the token needs to be refreshed, False otherwise.
"""
refresh: bool = int(time.time()) > self._token_exp - leeway
return refresh

def _validate(self) -> None:
"""Validate the token."""
if type(self._token) not in (str, bytes):
if type(self._token) is not str:
raise TypeError("Token must be str or bytes")

jwt_payload = jwt.decode(
Expand Down
4 changes: 2 additions & 2 deletions arangoasync/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,11 @@ def level(self, value: int) -> None:
self._level = value

@property
def accept_encoding(self) -> str | None:
def accept_encoding(self) -> Optional[str]:
return self._accept_encoding

@accept_encoding.setter
def accept_encoding(self, value: AcceptEncoding | None) -> None:
def accept_encoding(self, value: Optional[AcceptEncoding]) -> None:
self._accept_encoding = value.name.lower() if value else None

@property
Expand Down
162 changes: 158 additions & 4 deletions arangoasync/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,19 @@
"BasicConnection",
]

import json
from abc import ABC, abstractmethod
from typing import Any, List, Optional

from arangoasync.auth import Auth
import jwt

from arangoasync.auth import Auth, JwtToken
from arangoasync.compression import CompressionManager, DefaultCompressionManager
from arangoasync.exceptions import (
AuthHeaderError,
ClientConnectionAbortedError,
ClientConnectionError,
ConnectionAbortedError,
JWTRefreshError,
ServerConnectionError,
)
from arangoasync.http import HTTPClient
Expand Down Expand Up @@ -63,6 +68,7 @@ def prep_response(self, request: Request, resp: Response) -> Response:
Raises:
ServerConnectionError: If the response status code is not successful.
"""
# TODO needs refactoring such that it does not throw
resp.is_success = 200 <= resp.status_code < 300
if resp.status_code in {401, 403}:
raise ServerConnectionError(resp, request, "Authentication failed.")
Expand Down Expand Up @@ -97,7 +103,7 @@ async def process_request(self, request: Request) -> Response:
self._host_resolver.change_host()
host_index = self._host_resolver.get_host_index()

raise ConnectionAbortedError(
raise ClientConnectionAbortedError(
f"Can't connect to host(s) within limit ({self._host_resolver.max_tries})"
)

Expand All @@ -111,6 +117,7 @@ async def ping(self) -> int:
ServerConnectionError: If the response status code is not successful.
"""
request = Request(method=Method.GET, endpoint="/_api/collection")
request.headers = {"abde": "fghi"}
resp = await self.send_request(request)
return resp.status_code

Expand Down Expand Up @@ -154,7 +161,18 @@ def __init__(
self._auth = auth

async def send_request(self, request: Request) -> Response:
"""Send an HTTP request to the ArangoDB server."""
"""Send an HTTP request to the ArangoDB server.

Args:
request (Request): HTTP request.

Returns:
Response: HTTP response

Raises:
ArangoClientError: If an error occurred from the client side.
ArangoServerError: If an error occurred from the server side.
"""
if request.data is not None and self._compression.needs_compression(
request.data
):
Expand All @@ -169,3 +187,139 @@ async def send_request(self, request: Request) -> Response:
request.auth = self._auth

return await self.process_request(request)


class JwtConnection(BaseConnection):
"""Connection to a specific ArangoDB database, using JWT authentication.

Providing login information (username and password), allows to refresh the JWT.

Args:
sessions (list): List of client sessions.
host_resolver (HostResolver): Host resolver.
http_client (HTTPClient): HTTP client.
db_name (str): Database name.
compression (CompressionManager | None): Compression manager.
auth (Auth | None): Authentication information.
token (JwtToken | None): JWT token.

Raises:
ValueError: If neither token nor auth is provided.
"""

def __init__(
self,
sessions: List[Any],
host_resolver: HostResolver,
http_client: HTTPClient,
db_name: str,
compression: Optional[CompressionManager] = None,
auth: Optional[Auth] = None,
token: Optional[JwtToken] = None,
) -> None:
super().__init__(sessions, host_resolver, http_client, db_name, compression)
self._auth = auth
self._expire_leeway: int = 0
self._token: Optional[JwtToken] = None
self._auth_header: Optional[str] = None
self.token = token

if self._token is None and self._auth is None:
raise ValueError("Either token or auth must be provided.")

@property
def token(self) -> Optional[JwtToken]:
"""Get the JWT token.

Returns:
JwtToken | None: JWT token.
"""
return self._token

@token.setter
def token(self, token: Optional[JwtToken]) -> None:
"""Set the JWT token.

Args:
token (JwtToken | None): JWT token.
Setting it to None will cause the token to be automatically
refreshed on the next request, if auth information is provided.
"""
self._token = token
self._auth_header = f"bearer {self._token.token}" if self._token else None

async def refresh_token(self) -> None:
"""Refresh the JWT token.

Raises:
JWTRefreshError: If the token can't be refreshed.
"""
if self._auth is None:
raise JWTRefreshError("Auth must be provided to refresh the token.")

data = json.dumps(
dict(username=self._auth.username, password=self._auth.password),
separators=(",", ":"),
ensure_ascii=False,
)
request = Request(
method=Method.POST,
endpoint="/_open/auth",
data=data.encode("utf-8"),
)

try:
resp = await self.process_request(request)
except ClientConnectionAbortedError as e:
raise JWTRefreshError(str(e)) from e
except ServerConnectionError as e:
raise JWTRefreshError(str(e)) from e

if not resp.is_success:
raise JWTRefreshError(
f"Failed to refresh the JWT token: "
f"{resp.status_code} {resp.status_text}"
)

token = json.loads(resp.raw_body)
try:
self.token = JwtToken(token["jwt"])
except jwt.ExpiredSignatureError as e:
raise JWTRefreshError(
"Failed to refresh the JWT token: got an expired token"
) from e

async def send_request(self, request: Request) -> Response:
"""Send an HTTP request to the ArangoDB server.

Args:
request (Request): HTTP request.

Returns:
Response: HTTP response

Raises:
ArangoClientError: If an error occurred from the client side.
ArangoServerError: If an error occurred from the server side.
"""
if self._auth_header is None:
await self.refresh_token()

if self._auth_header is None:
raise AuthHeaderError("Failed to generate authorization header.")

request.headers["authorization"] = self._auth_header

try:
resp = await self.process_request(request)
if (
resp.status_code == 401 # Unauthorized
and self._token is not None
and self._token.needs_refresh(self._expire_leeway)
):
await self.refresh_token()
return await self.process_request(request) # Retry with new token
except ServerConnectionError:
# TODO modify after refactoring of prep_response, so we can inspect response
await self.refresh_token()
return await self.process_request(request) # Retry with new token
Loading
Loading