3
3
"BasicConnection" ,
4
4
]
5
5
6
+ import json
6
7
from abc import ABC , abstractmethod
7
8
from typing import Any , List , Optional
8
9
9
- from arangoasync .auth import Auth
10
+ import jwt
11
+
12
+ from arangoasync .auth import Auth , JwtToken
10
13
from arangoasync .compression import CompressionManager , DefaultCompressionManager
11
14
from arangoasync .exceptions import (
15
+ AuthHeaderError ,
16
+ ClientConnectionAbortedError ,
12
17
ClientConnectionError ,
13
- ConnectionAbortedError ,
18
+ JWTRefreshError ,
14
19
ServerConnectionError ,
15
20
)
16
21
from arangoasync .http import HTTPClient
@@ -63,6 +68,7 @@ def prep_response(self, request: Request, resp: Response) -> Response:
63
68
Raises:
64
69
ServerConnectionError: If the response status code is not successful.
65
70
"""
71
+ # TODO needs refactoring such that it does not throw
66
72
resp .is_success = 200 <= resp .status_code < 300
67
73
if resp .status_code in {401 , 403 }:
68
74
raise ServerConnectionError (resp , request , "Authentication failed." )
@@ -97,7 +103,7 @@ async def process_request(self, request: Request) -> Response:
97
103
self ._host_resolver .change_host ()
98
104
host_index = self ._host_resolver .get_host_index ()
99
105
100
- raise ConnectionAbortedError (
106
+ raise ClientConnectionAbortedError (
101
107
f"Can't connect to host(s) within limit ({ self ._host_resolver .max_tries } )"
102
108
)
103
109
@@ -111,6 +117,7 @@ async def ping(self) -> int:
111
117
ServerConnectionError: If the response status code is not successful.
112
118
"""
113
119
request = Request (method = Method .GET , endpoint = "/_api/collection" )
120
+ request .headers = {"abde" : "fghi" }
114
121
resp = await self .send_request (request )
115
122
return resp .status_code
116
123
@@ -154,7 +161,18 @@ def __init__(
154
161
self ._auth = auth
155
162
156
163
async def send_request (self , request : Request ) -> Response :
157
- """Send an HTTP request to the ArangoDB server."""
164
+ """Send an HTTP request to the ArangoDB server.
165
+
166
+ Args:
167
+ request (Request): HTTP request.
168
+
169
+ Returns:
170
+ Response: HTTP response
171
+
172
+ Raises:
173
+ ArangoClientError: If an error occurred from the client side.
174
+ ArangoServerError: If an error occurred from the server side.
175
+ """
158
176
if request .data is not None and self ._compression .needs_compression (
159
177
request .data
160
178
):
@@ -169,3 +187,139 @@ async def send_request(self, request: Request) -> Response:
169
187
request .auth = self ._auth
170
188
171
189
return await self .process_request (request )
190
+
191
+
192
+ class JwtConnection (BaseConnection ):
193
+ """Connection to a specific ArangoDB database, using JWT authentication.
194
+
195
+ Providing login information (username and password), allows to refresh the JWT.
196
+
197
+ Args:
198
+ sessions (list): List of client sessions.
199
+ host_resolver (HostResolver): Host resolver.
200
+ http_client (HTTPClient): HTTP client.
201
+ db_name (str): Database name.
202
+ compression (CompressionManager | None): Compression manager.
203
+ auth (Auth | None): Authentication information.
204
+ token (JwtToken | None): JWT token.
205
+
206
+ Raises:
207
+ ValueError: If neither token nor auth is provided.
208
+ """
209
+
210
+ def __init__ (
211
+ self ,
212
+ sessions : List [Any ],
213
+ host_resolver : HostResolver ,
214
+ http_client : HTTPClient ,
215
+ db_name : str ,
216
+ compression : Optional [CompressionManager ] = None ,
217
+ auth : Optional [Auth ] = None ,
218
+ token : Optional [JwtToken ] = None ,
219
+ ) -> None :
220
+ super ().__init__ (sessions , host_resolver , http_client , db_name , compression )
221
+ self ._auth = auth
222
+ self ._expire_leeway : int = 0
223
+ self ._token : Optional [JwtToken ] = None
224
+ self ._auth_header : Optional [str ] = None
225
+ self .token = token
226
+
227
+ if self ._token is None and self ._auth is None :
228
+ raise ValueError ("Either token or auth must be provided." )
229
+
230
+ @property
231
+ def token (self ) -> Optional [JwtToken ]:
232
+ """Get the JWT token.
233
+
234
+ Returns:
235
+ JwtToken | None: JWT token.
236
+ """
237
+ return self ._token
238
+
239
+ @token .setter
240
+ def token (self , token : Optional [JwtToken ]) -> None :
241
+ """Set the JWT token.
242
+
243
+ Args:
244
+ token (JwtToken | None): JWT token.
245
+ Setting it to None will cause the token to be automatically
246
+ refreshed on the next request, if auth information is provided.
247
+ """
248
+ self ._token = token
249
+ self ._auth_header = f"bearer { self ._token .token } " if self ._token else None
250
+
251
+ async def refresh_token (self ) -> None :
252
+ """Refresh the JWT token.
253
+
254
+ Raises:
255
+ JWTRefreshError: If the token can't be refreshed.
256
+ """
257
+ if self ._auth is None :
258
+ raise JWTRefreshError ("Auth must be provided to refresh the token." )
259
+
260
+ data = json .dumps (
261
+ dict (username = self ._auth .username , password = self ._auth .password ),
262
+ separators = ("," , ":" ),
263
+ ensure_ascii = False ,
264
+ )
265
+ request = Request (
266
+ method = Method .POST ,
267
+ endpoint = "/_open/auth" ,
268
+ data = data .encode ("utf-8" ),
269
+ )
270
+
271
+ try :
272
+ resp = await self .process_request (request )
273
+ except ClientConnectionAbortedError as e :
274
+ raise JWTRefreshError (str (e )) from e
275
+ except ServerConnectionError as e :
276
+ raise JWTRefreshError (str (e )) from e
277
+
278
+ if not resp .is_success :
279
+ raise JWTRefreshError (
280
+ f"Failed to refresh the JWT token: "
281
+ f"{ resp .status_code } { resp .status_text } "
282
+ )
283
+
284
+ token = json .loads (resp .raw_body )
285
+ try :
286
+ self .token = JwtToken (token ["jwt" ])
287
+ except jwt .ExpiredSignatureError as e :
288
+ raise JWTRefreshError (
289
+ "Failed to refresh the JWT token: got an expired token"
290
+ ) from e
291
+
292
+ async def send_request (self , request : Request ) -> Response :
293
+ """Send an HTTP request to the ArangoDB server.
294
+
295
+ Args:
296
+ request (Request): HTTP request.
297
+
298
+ Returns:
299
+ Response: HTTP response
300
+
301
+ Raises:
302
+ ArangoClientError: If an error occurred from the client side.
303
+ ArangoServerError: If an error occurred from the server side.
304
+ """
305
+ if self ._auth_header is None :
306
+ await self .refresh_token ()
307
+
308
+ if self ._auth_header is None :
309
+ raise AuthHeaderError ("Failed to generate authorization header." )
310
+
311
+ request .headers ["authorization" ] = self ._auth_header
312
+
313
+ try :
314
+ resp = await self .process_request (request )
315
+ if (
316
+ resp .status_code == 401 # Unauthorized
317
+ and self ._token is not None
318
+ and self ._token .needs_refresh (self ._expire_leeway )
319
+ ):
320
+ await self .refresh_token ()
321
+ return await self .process_request (request ) # Retry with new token
322
+ except ServerConnectionError :
323
+ # TODO modify after refactoring of prep_response, so we can inspect response
324
+ await self .refresh_token ()
325
+ return await self .process_request (request ) # Retry with new token
0 commit comments