Skip to content

Commit bd76c61

Browse files
authored
Improved tests (#13)
* Improved tests * Marking test async
1 parent b67f60d commit bd76c61

File tree

6 files changed

+157
-38
lines changed

6 files changed

+157
-38
lines changed

arangoasync/compression.py

+32-10
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def compress(self, data: str | bytes) -> bytes:
5353
"""
5454
raise NotImplementedError
5555

56+
@property
5657
@abstractmethod
5758
def content_encoding(self) -> str:
5859
"""Return the content encoding.
@@ -65,6 +66,7 @@ def content_encoding(self) -> str:
6566
"""
6667
raise NotImplementedError
6768

69+
@property
6870
@abstractmethod
6971
def accept_encoding(self) -> str | None:
7072
"""Return the accept encoding.
@@ -101,18 +103,38 @@ def __init__(
101103
self._content_encoding = ContentEncoding.DEFLATE.name.lower()
102104
self._accept_encoding = accept.name.lower() if accept else None
103105

104-
def needs_compression(self, data: str | bytes) -> bool:
105-
return self._threshold != -1 and len(data) >= self._threshold
106+
@property
107+
def threshold(self) -> int:
108+
return self._threshold
106109

107-
def compress(self, data: str | bytes) -> bytes:
108-
if data is not None:
109-
if isinstance(data, bytes):
110-
return zlib.compress(data, self._level)
111-
return zlib.compress(data.encode("utf-8"), self._level)
112-
return b""
110+
@threshold.setter
111+
def threshold(self, value: int) -> None:
112+
self._threshold = value
113113

114-
def content_encoding(self) -> str:
115-
return self._content_encoding
114+
@property
115+
def level(self) -> int:
116+
return self._level
117+
118+
@level.setter
119+
def level(self, value: int) -> None:
120+
self._level = value
116121

122+
@property
117123
def accept_encoding(self) -> str | None:
118124
return self._accept_encoding
125+
126+
@accept_encoding.setter
127+
def accept_encoding(self, value: AcceptEncoding | None) -> None:
128+
self._accept_encoding = value.name.lower() if value else None
129+
130+
@property
131+
def content_encoding(self) -> str:
132+
return self._content_encoding
133+
134+
def needs_compression(self, data: str | bytes) -> bool:
135+
return self._threshold != -1 and len(data) >= self._threshold
136+
137+
def compress(self, data: str | bytes) -> bytes:
138+
if isinstance(data, bytes):
139+
return zlib.compress(data, self._level)
140+
return zlib.compress(data.encode("utf-8"), self._level)

arangoasync/connection.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,10 @@ def prep_response(self, request: Request, resp: Response) -> Response:
6464
ServerConnectionError: If the response status code is not successful.
6565
"""
6666
resp.is_success = 200 <= resp.status_code < 300
67+
if resp.status_code in {401, 403}:
68+
raise ServerConnectionError(resp, request, "Authentication failed.")
6769
if not resp.is_success:
68-
raise ServerConnectionError(resp, request)
70+
raise ServerConnectionError(resp, request, "Bad server response.")
6971
return resp
7072

7173
async def process_request(self, request: Request) -> Response:
@@ -110,10 +112,6 @@ async def ping(self) -> int:
110112
"""
111113
request = Request(method=Method.GET, endpoint="/_api/collection")
112114
resp = await self.send_request(request)
113-
if resp.status_code in {401, 403}:
114-
raise ServerConnectionError(resp, request, "Authentication failed.")
115-
if not resp.is_success:
116-
raise ServerConnectionError(resp, request, "Bad server response.")
117115
return resp.status_code
118116

119117
@abstractmethod
@@ -161,9 +159,9 @@ async def send_request(self, request: Request) -> Response:
161159
request.data
162160
):
163161
request.data = self._compression.compress(request.data)
164-
request.headers["content-encoding"] = self._compression.content_encoding()
162+
request.headers["content-encoding"] = self._compression.content_encoding
165163

166-
accept_encoding: str | None = self._compression.accept_encoding()
164+
accept_encoding: str | None = self._compression.accept_encoding
167165
if accept_encoding is not None:
168166
request.headers["accept-encoding"] = accept_encoding
169167

arangoasync/http.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,8 @@ async def send_request(
151151
async with session.request(
152152
request.method.name,
153153
request.endpoint,
154-
headers=request.headers,
155-
params=request.params,
154+
headers=request.normalized_headers(),
155+
params=request.normalized_params(),
156156
data=request.data,
157157
auth=auth,
158158
) as response:

arangoasync/request.py

+8-16
Original file line numberDiff line numberDiff line change
@@ -63,18 +63,14 @@ def __init__(
6363
) -> None:
6464
self.method: Method = method
6565
self.endpoint: str = endpoint
66-
self.headers: RequestHeaders = self._normalize_headers(headers)
67-
self.params: Params = self._normalize_params(params)
66+
self.headers: RequestHeaders = headers or dict()
67+
self.params: Params = params or dict()
6868
self.data: Optional[bytes] = data
6969
self.auth: Optional[Auth] = auth
7070

71-
@staticmethod
72-
def _normalize_headers(headers: Optional[RequestHeaders]) -> RequestHeaders:
71+
def normalized_headers(self) -> RequestHeaders:
7372
"""Normalize request headers.
7473
75-
Parameters:
76-
headers (dict | None): Request headers.
77-
7874
Returns:
7975
dict: Normalized request headers.
8076
"""
@@ -85,26 +81,22 @@ def _normalize_headers(headers: Optional[RequestHeaders]) -> RequestHeaders:
8581
"x-arango-driver": driver_header,
8682
}
8783

88-
if headers is not None:
89-
for key, value in headers.items():
84+
if self.headers is not None:
85+
for key, value in self.headers.items():
9086
normalized_headers[key.lower()] = value
9187

9288
return normalized_headers
9389

94-
@staticmethod
95-
def _normalize_params(params: Optional[Params]) -> Params:
90+
def normalized_params(self) -> Params:
9691
"""Normalize URL parameters.
9792
98-
Parameters:
99-
params (dict | None): URL parameters.
100-
10193
Returns:
10294
dict: Normalized URL parameters.
10395
"""
10496
normalized_params: Params = {}
10597

106-
if params is not None:
107-
for key, value in params.items():
98+
if self.params is not None:
99+
for key, value in self.params.items():
108100
if isinstance(value, bool):
109101
value = int(value)
110102
normalized_params[key] = str(value)

tests/test_compression.py

+21-2
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,24 @@ def test_DefaultCompressionManager_compress():
1616
data = "a" * 10 + "b" * 10
1717
assert manager.needs_compression(data)
1818
assert len(manager.compress(data)) < len(data)
19-
assert manager.content_encoding() == "deflate"
20-
assert manager.accept_encoding() == "deflate"
19+
assert manager.content_encoding == "deflate"
20+
assert manager.accept_encoding == "deflate"
21+
data = b"a" * 10 + b"b" * 10
22+
assert manager.needs_compression(data)
23+
assert len(manager.compress(data)) < len(data)
24+
25+
26+
def test_DefaultCompressionManager_properties():
27+
manager = DefaultCompressionManager(
28+
threshold=1, level=9, accept=AcceptEncoding.DEFLATE
29+
)
30+
assert manager.threshold == 1
31+
assert manager.level == 9
32+
assert manager.accept_encoding == "deflate"
33+
assert manager.content_encoding == "deflate"
34+
manager.threshold = 10
35+
assert manager.threshold == 10
36+
manager.level = 2
37+
assert manager.level == 2
38+
manager.accept_encoding = AcceptEncoding.GZIP
39+
assert manager.accept_encoding == "gzip"

tests/test_connection.py

+89-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,19 @@
1+
import zlib
2+
13
import pytest
24

35
from arangoasync.auth import Auth
6+
from arangoasync.compression import AcceptEncoding, DefaultCompressionManager
47
from arangoasync.connection import BasicConnection
5-
from arangoasync.exceptions import ServerConnectionError
8+
from arangoasync.exceptions import (
9+
ClientConnectionError,
10+
ConnectionAbortedError,
11+
ServerConnectionError,
12+
)
613
from arangoasync.http import AioHTTPClient
14+
from arangoasync.request import Method, Request
715
from arangoasync.resolver import DefaultHostResolver
16+
from arangoasync.response import Response
817

918

1019
@pytest.mark.asyncio
@@ -40,5 +49,84 @@ async def test_BasicConnection_ping_success(
4049
auth=Auth(username=root, password=password),
4150
)
4251

52+
assert connection.db_name == sys_db_name
4353
status_code = await connection.ping()
4454
assert status_code == 200
55+
56+
57+
@pytest.mark.asyncio
58+
async def test_BasicConnection_with_compression(
59+
client_session, url, sys_db_name, root, password
60+
):
61+
client = AioHTTPClient()
62+
session = client_session(client, url)
63+
resolver = DefaultHostResolver(1)
64+
compression = DefaultCompressionManager(
65+
threshold=2, level=5, accept=AcceptEncoding.DEFLATE
66+
)
67+
68+
connection = BasicConnection(
69+
sessions=[session],
70+
host_resolver=resolver,
71+
http_client=client,
72+
db_name=sys_db_name,
73+
auth=Auth(username=root, password=password),
74+
compression=compression,
75+
)
76+
77+
data = b"a" * 100
78+
request = Request(method=Method.GET, endpoint="/_api/collection", data=data)
79+
_ = await connection.send_request(request)
80+
assert len(request.data) < len(data)
81+
assert zlib.decompress(request.data) == data
82+
assert request.headers["content-encoding"] == "deflate"
83+
assert request.headers["accept-encoding"] == "deflate"
84+
85+
86+
@pytest.mark.asyncio
87+
async def test_BasicConnection_prep_response_bad_response(
88+
client_session, url, sys_db_name
89+
):
90+
client = AioHTTPClient()
91+
session = client_session(client, url)
92+
resolver = DefaultHostResolver(1)
93+
94+
connection = BasicConnection(
95+
sessions=[session],
96+
host_resolver=resolver,
97+
http_client=client,
98+
db_name=sys_db_name,
99+
)
100+
101+
request = Request(method=Method.GET, endpoint="/_api/collection")
102+
response = Response(Method.GET, url, {}, 0, "ERROR", b"")
103+
104+
with pytest.raises(ServerConnectionError):
105+
connection.prep_response(request, response)
106+
107+
108+
@pytest.mark.asyncio
109+
async def test_BasicConnection_process_request_connection_aborted(
110+
monkeypatch, client_session, url, sys_db_name, root, password
111+
):
112+
client = AioHTTPClient()
113+
session = client_session(client, url)
114+
resolver = DefaultHostResolver(1, 1)
115+
116+
request = Request(method=Method.GET, endpoint="/_api/collection")
117+
118+
async def mock_send_request(*args, **kwargs):
119+
raise ClientConnectionError("test")
120+
121+
monkeypatch.setattr(client, "send_request", mock_send_request)
122+
123+
connection = BasicConnection(
124+
sessions=[session],
125+
host_resolver=resolver,
126+
http_client=client,
127+
db_name=sys_db_name,
128+
auth=Auth(username=root, password=password),
129+
)
130+
131+
with pytest.raises(ConnectionAbortedError):
132+
await connection.process_request(request)

0 commit comments

Comments
 (0)