Skip to content

Commit 44f4fa0

Browse files
authored
JWT Connection Support (#14)
* Working on JWT Connection Support * JWT Connection working * Fixing liter error
1 parent bd76c61 commit 44f4fa0

File tree

8 files changed

+1601
-16
lines changed

8 files changed

+1601
-16
lines changed

arangoasync/auth.py

+20-6
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"JwtToken",
44
]
55

6+
import time
67
from dataclasses import dataclass
78

89
import jwt
@@ -27,24 +28,24 @@ class JwtToken:
2728
"""JWT token.
2829
2930
Args:
30-
token (str | bytes): JWT token.
31+
token (str): JWT token.
3132
3233
Raises:
3334
TypeError: If the token type is not str or bytes.
34-
JWTExpiredError: If the token expired.
35+
jwt.ExpiredSignatureError: If the token expired.
3536
"""
3637

37-
def __init__(self, token: str | bytes) -> None:
38+
def __init__(self, token: str) -> None:
3839
self._token = token
3940
self._validate()
4041

4142
@property
42-
def token(self) -> str | bytes:
43+
def token(self) -> str:
4344
"""Get token."""
4445
return self._token
4546

4647
@token.setter
47-
def token(self, token: str | bytes) -> None:
48+
def token(self, token: str) -> None:
4849
"""Set token.
4950
5051
Raises:
@@ -53,9 +54,22 @@ def token(self, token: str | bytes) -> None:
5354
self._token = token
5455
self._validate()
5556

57+
def needs_refresh(self, leeway: int = 0) -> bool:
58+
"""Check if the token needs to be refreshed.
59+
60+
Args:
61+
leeway (int): Leeway in seconds, before official expiration,
62+
when to consider the token expired.
63+
64+
Returns:
65+
bool: True if the token needs to be refreshed, False otherwise.
66+
"""
67+
refresh: bool = int(time.time()) > self._token_exp - leeway
68+
return refresh
69+
5670
def _validate(self) -> None:
5771
"""Validate the token."""
58-
if type(self._token) not in (str, bytes):
72+
if type(self._token) is not str:
5973
raise TypeError("Token must be str or bytes")
6074

6175
jwt_payload = jwt.decode(

arangoasync/compression.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -120,11 +120,11 @@ def level(self, value: int) -> None:
120120
self._level = value
121121

122122
@property
123-
def accept_encoding(self) -> str | None:
123+
def accept_encoding(self) -> Optional[str]:
124124
return self._accept_encoding
125125

126126
@accept_encoding.setter
127-
def accept_encoding(self, value: AcceptEncoding | None) -> None:
127+
def accept_encoding(self, value: Optional[AcceptEncoding]) -> None:
128128
self._accept_encoding = value.name.lower() if value else None
129129

130130
@property

arangoasync/connection.py

+158-4
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,19 @@
33
"BasicConnection",
44
]
55

6+
import json
67
from abc import ABC, abstractmethod
78
from typing import Any, List, Optional
89

9-
from arangoasync.auth import Auth
10+
import jwt
11+
12+
from arangoasync.auth import Auth, JwtToken
1013
from arangoasync.compression import CompressionManager, DefaultCompressionManager
1114
from arangoasync.exceptions import (
15+
AuthHeaderError,
16+
ClientConnectionAbortedError,
1217
ClientConnectionError,
13-
ConnectionAbortedError,
18+
JWTRefreshError,
1419
ServerConnectionError,
1520
)
1621
from arangoasync.http import HTTPClient
@@ -63,6 +68,7 @@ def prep_response(self, request: Request, resp: Response) -> Response:
6368
Raises:
6469
ServerConnectionError: If the response status code is not successful.
6570
"""
71+
# TODO needs refactoring such that it does not throw
6672
resp.is_success = 200 <= resp.status_code < 300
6773
if resp.status_code in {401, 403}:
6874
raise ServerConnectionError(resp, request, "Authentication failed.")
@@ -97,7 +103,7 @@ async def process_request(self, request: Request) -> Response:
97103
self._host_resolver.change_host()
98104
host_index = self._host_resolver.get_host_index()
99105

100-
raise ConnectionAbortedError(
106+
raise ClientConnectionAbortedError(
101107
f"Can't connect to host(s) within limit ({self._host_resolver.max_tries})"
102108
)
103109

@@ -111,6 +117,7 @@ async def ping(self) -> int:
111117
ServerConnectionError: If the response status code is not successful.
112118
"""
113119
request = Request(method=Method.GET, endpoint="/_api/collection")
120+
request.headers = {"abde": "fghi"}
114121
resp = await self.send_request(request)
115122
return resp.status_code
116123

@@ -154,7 +161,18 @@ def __init__(
154161
self._auth = auth
155162

156163
async def send_request(self, request: Request) -> Response:
157-
"""Send an HTTP request to the ArangoDB server."""
164+
"""Send an HTTP request to the ArangoDB server.
165+
166+
Args:
167+
request (Request): HTTP request.
168+
169+
Returns:
170+
Response: HTTP response
171+
172+
Raises:
173+
ArangoClientError: If an error occurred from the client side.
174+
ArangoServerError: If an error occurred from the server side.
175+
"""
158176
if request.data is not None and self._compression.needs_compression(
159177
request.data
160178
):
@@ -169,3 +187,139 @@ async def send_request(self, request: Request) -> Response:
169187
request.auth = self._auth
170188

171189
return await self.process_request(request)
190+
191+
192+
class JwtConnection(BaseConnection):
193+
"""Connection to a specific ArangoDB database, using JWT authentication.
194+
195+
Providing login information (username and password), allows to refresh the JWT.
196+
197+
Args:
198+
sessions (list): List of client sessions.
199+
host_resolver (HostResolver): Host resolver.
200+
http_client (HTTPClient): HTTP client.
201+
db_name (str): Database name.
202+
compression (CompressionManager | None): Compression manager.
203+
auth (Auth | None): Authentication information.
204+
token (JwtToken | None): JWT token.
205+
206+
Raises:
207+
ValueError: If neither token nor auth is provided.
208+
"""
209+
210+
def __init__(
211+
self,
212+
sessions: List[Any],
213+
host_resolver: HostResolver,
214+
http_client: HTTPClient,
215+
db_name: str,
216+
compression: Optional[CompressionManager] = None,
217+
auth: Optional[Auth] = None,
218+
token: Optional[JwtToken] = None,
219+
) -> None:
220+
super().__init__(sessions, host_resolver, http_client, db_name, compression)
221+
self._auth = auth
222+
self._expire_leeway: int = 0
223+
self._token: Optional[JwtToken] = None
224+
self._auth_header: Optional[str] = None
225+
self.token = token
226+
227+
if self._token is None and self._auth is None:
228+
raise ValueError("Either token or auth must be provided.")
229+
230+
@property
231+
def token(self) -> Optional[JwtToken]:
232+
"""Get the JWT token.
233+
234+
Returns:
235+
JwtToken | None: JWT token.
236+
"""
237+
return self._token
238+
239+
@token.setter
240+
def token(self, token: Optional[JwtToken]) -> None:
241+
"""Set the JWT token.
242+
243+
Args:
244+
token (JwtToken | None): JWT token.
245+
Setting it to None will cause the token to be automatically
246+
refreshed on the next request, if auth information is provided.
247+
"""
248+
self._token = token
249+
self._auth_header = f"bearer {self._token.token}" if self._token else None
250+
251+
async def refresh_token(self) -> None:
252+
"""Refresh the JWT token.
253+
254+
Raises:
255+
JWTRefreshError: If the token can't be refreshed.
256+
"""
257+
if self._auth is None:
258+
raise JWTRefreshError("Auth must be provided to refresh the token.")
259+
260+
data = json.dumps(
261+
dict(username=self._auth.username, password=self._auth.password),
262+
separators=(",", ":"),
263+
ensure_ascii=False,
264+
)
265+
request = Request(
266+
method=Method.POST,
267+
endpoint="/_open/auth",
268+
data=data.encode("utf-8"),
269+
)
270+
271+
try:
272+
resp = await self.process_request(request)
273+
except ClientConnectionAbortedError as e:
274+
raise JWTRefreshError(str(e)) from e
275+
except ServerConnectionError as e:
276+
raise JWTRefreshError(str(e)) from e
277+
278+
if not resp.is_success:
279+
raise JWTRefreshError(
280+
f"Failed to refresh the JWT token: "
281+
f"{resp.status_code} {resp.status_text}"
282+
)
283+
284+
token = json.loads(resp.raw_body)
285+
try:
286+
self.token = JwtToken(token["jwt"])
287+
except jwt.ExpiredSignatureError as e:
288+
raise JWTRefreshError(
289+
"Failed to refresh the JWT token: got an expired token"
290+
) from e
291+
292+
async def send_request(self, request: Request) -> Response:
293+
"""Send an HTTP request to the ArangoDB server.
294+
295+
Args:
296+
request (Request): HTTP request.
297+
298+
Returns:
299+
Response: HTTP response
300+
301+
Raises:
302+
ArangoClientError: If an error occurred from the client side.
303+
ArangoServerError: If an error occurred from the server side.
304+
"""
305+
if self._auth_header is None:
306+
await self.refresh_token()
307+
308+
if self._auth_header is None:
309+
raise AuthHeaderError("Failed to generate authorization header.")
310+
311+
request.headers["authorization"] = self._auth_header
312+
313+
try:
314+
resp = await self.process_request(request)
315+
if (
316+
resp.status_code == 401 # Unauthorized
317+
and self._token is not None
318+
and self._token.needs_refresh(self._expire_leeway)
319+
):
320+
await self.refresh_token()
321+
return await self.process_request(request) # Retry with new token
322+
except ServerConnectionError:
323+
# TODO modify after refactoring of prep_response, so we can inspect response
324+
await self.refresh_token()
325+
return await self.process_request(request) # Retry with new token

0 commit comments

Comments
 (0)