-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathconnection.py
171 lines (139 loc) · 5.53 KB
/
connection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
# Names exported by ``from <this module> import *``.
__all__ = ["BaseConnection", "BasicConnection"]
from abc import ABC, abstractmethod
from typing import Any, List, Optional
from arangoasync.auth import Auth
from arangoasync.compression import CompressionManager, DefaultCompressionManager
from arangoasync.exceptions import (
ClientConnectionError,
ConnectionAbortedError,
ServerConnectionError,
)
from arangoasync.http import HTTPClient
from arangoasync.request import Method, Request
from arangoasync.resolver import HostResolver
from arangoasync.response import Response
class BaseConnection(ABC):
    """Blueprint for connection to a specific ArangoDB database.

    Args:
        sessions (list): List of client sessions, indexed by the host index
            returned from the host resolver.
        host_resolver (HostResolver): Host resolver.
        http_client (HTTPClient): HTTP client.
        db_name (str): Database name.
        compression (CompressionManager | None): Compression manager. A
            ``DefaultCompressionManager`` is used when none is given.
    """

    def __init__(
        self,
        sessions: List[Any],
        host_resolver: HostResolver,
        http_client: HTTPClient,
        db_name: str,
        compression: Optional[CompressionManager] = None,
    ) -> None:
        self._sessions = sessions
        self._db_endpoint = f"/_db/{db_name}"
        self._host_resolver = host_resolver
        self._http_client = http_client
        self._db_name = db_name
        self._compression = compression or DefaultCompressionManager()

    @property
    def db_name(self) -> str:
        """Return the database name."""
        return self._db_name

    def prep_response(self, request: Request, resp: Response) -> Response:
        """Prepare response for return.

        Sets ``resp.is_success`` to whether the status code is in the 2xx
        range, then raises for any unsuccessful response.

        Args:
            request (Request): Request object.
            resp (Response): Response object. Its ``is_success`` attribute is
                overwritten.

        Returns:
            Response: The same response object.

        Raises:
            ServerConnectionError: If the status code is 401/403
                ("Authentication failed.") or otherwise not 2xx
                ("Bad server response.").
        """
        resp.is_success = 200 <= resp.status_code < 300
        # Authentication failures get a dedicated message; checked before the
        # generic non-2xx case so they are not reported as a plain bad response.
        if resp.status_code in {401, 403}:
            raise ServerConnectionError(resp, request, "Authentication failed.")
        if not resp.is_success:
            raise ServerConnectionError(resp, request, "Bad server response.")
        return resp

    async def process_request(self, request: Request) -> Response:
        """Process request, potentially trying multiple hosts.

        The request is sent through the session chosen by the host resolver.
        On a client-side connection error another host is selected and the
        request is retried, up to ``host_resolver.max_tries`` attempts in total.

        Args:
            request (Request): Request object.

        Returns:
            Response: Response object.

        Raises:
            ConnectionAbortedError: If no host could be reached within
                ``max_tries`` attempts (``arangoasync.exceptions`` version,
                which shadows the builtin of the same name). The last
                ``ClientConnectionError`` is attached as ``__cause__``.
            ServerConnectionError: If a host responded with an unsuccessful
                status code (not retried).
        """
        host_index = self._host_resolver.get_host_index()
        last_error: Optional[ClientConnectionError] = None
        for _ in range(self._host_resolver.max_tries):
            try:
                resp = await self._http_client.send_request(
                    self._sessions[host_index], request
                )
                return self.prep_response(request, resp)
            except ClientConnectionError as err:
                # Keep the failure so the final error can report its cause.
                last_error = err
                ex_host_index = host_index
                host_index = self._host_resolver.get_host_index()
                if ex_host_index == host_index:
                    # The resolver handed back the host that just failed;
                    # explicitly ask it to switch before retrying.
                    self._host_resolver.change_host()
                    host_index = self._host_resolver.get_host_index()

        raise ConnectionAbortedError(
            f"Can't connect to host(s) within limit ({self._host_resolver.max_tries})"
        ) from last_error

    async def ping(self) -> int:
        """Ping host to check if connection is established.

        Issues ``GET /_api/collection`` through :meth:`send_request`.

        Returns:
            int: Response status code.

        Raises:
            ServerConnectionError: If the response status code is not
                successful (as raised by the ``prep_response`` path of
                ``process_request``).
        """
        request = Request(method=Method.GET, endpoint="/_api/collection")
        resp = await self.send_request(request)
        return resp.status_code

    @abstractmethod
    async def send_request(self, request: Request) -> Response:  # pragma: no cover
        """Send an HTTP request to the ArangoDB server.

        Args:
            request (Request): HTTP request.

        Returns:
            Response: HTTP response.
        """
        raise NotImplementedError
class BasicConnection(BaseConnection):
    """Connection to a specific ArangoDB database.

    Allows for basic authentication to be used (username and password).

    Args:
        sessions (list): List of client sessions.
        host_resolver (HostResolver): Host resolver.
        http_client (HTTPClient): HTTP client.
        db_name (str): Database name.
        compression (CompressionManager | None): Compression manager.
        auth (Auth | None): Authentication information.
    """

    def __init__(
        self,
        sessions: List[Any],
        host_resolver: HostResolver,
        http_client: HTTPClient,
        db_name: str,
        compression: Optional[CompressionManager] = None,
        auth: Optional[Auth] = None,
    ) -> None:
        super().__init__(sessions, host_resolver, http_client, db_name, compression)
        self._auth = auth

    async def send_request(self, request: Request) -> Response:
        """Send an HTTP request to the ArangoDB server.

        The request is modified in place before sending:

        * its body is compressed and ``content-encoding`` set when the
          compression manager reports that the body needs compression;
        * ``accept-encoding`` is set when the manager provides a value;
        * the connection's credentials are attached when configured.

        Args:
            request (Request): HTTP request. Mutated in place.

        Returns:
            Response: HTTP response.

        Raises:
            ConnectionAbortedError: If no host could be reached (from
                ``process_request``).
            ServerConnectionError: If the server returned an unsuccessful
                status code (from ``process_request``).
        """
        if request.data is not None and self._compression.needs_compression(
            request.data
        ):
            request.data = self._compression.compress(request.data)
            request.headers["content-encoding"] = self._compression.content_encoding

        # Optional[...] matches the annotation style used elsewhere in this file.
        accept_encoding: Optional[str] = self._compression.accept_encoding
        if accept_encoding is not None:
            request.headers["accept-encoding"] = accept_encoding

        if self._auth:
            request.auth = self._auth

        return await self.process_request(request)