From dfc29507061b2e852a8c75184fb08f9ac1d3254c Mon Sep 17 00:00:00 2001 From: Alex Petenchea Date: Sun, 25 Aug 2024 20:03:33 +0300 Subject: [PATCH 1/2] Improved tests --- arangoasync/compression.py | 42 +++++++++++++----- arangoasync/connection.py | 12 +++--- arangoasync/http.py | 4 +- arangoasync/request.py | 24 ++++------- tests/test_compression.py | 23 +++++++++- tests/test_connection.py | 87 +++++++++++++++++++++++++++++++++++++- 6 files changed, 154 insertions(+), 38 deletions(-) diff --git a/arangoasync/compression.py b/arangoasync/compression.py index 1151149..22b52d3 100644 --- a/arangoasync/compression.py +++ b/arangoasync/compression.py @@ -53,6 +53,7 @@ def compress(self, data: str | bytes) -> bytes: """ raise NotImplementedError + @property @abstractmethod def content_encoding(self) -> str: """Return the content encoding. @@ -65,6 +66,7 @@ def content_encoding(self) -> str: """ raise NotImplementedError + @property @abstractmethod def accept_encoding(self) -> str | None: """Return the accept encoding. @@ -101,18 +103,38 @@ def __init__( self._content_encoding = ContentEncoding.DEFLATE.name.lower() self._accept_encoding = accept.name.lower() if accept else None - def needs_compression(self, data: str | bytes) -> bool: - return self._threshold != -1 and len(data) >= self._threshold + @property + def threshold(self) -> int: + return self._threshold - def compress(self, data: str | bytes) -> bytes: - if data is not None: - if isinstance(data, bytes): - return zlib.compress(data, self._level) - return zlib.compress(data.encode("utf-8"), self._level) - return b"" + @threshold.setter + def threshold(self, value: int) -> None: + self._threshold = value - def content_encoding(self) -> str: - return self._content_encoding + @property + def level(self) -> int: + return self._level + + @level.setter + def level(self, value: int) -> None: + self._level = value + @property def accept_encoding(self) -> str | None: return self._accept_encoding + + @accept_encoding.setter + def accept_encoding(self, value: AcceptEncoding | None) -> None: + self._accept_encoding = value.name.lower() if value else None + + @property + def content_encoding(self) -> str: + return self._content_encoding + + def needs_compression(self, data: str | bytes) -> bool: + return self._threshold != -1 and len(data) >= self._threshold + + def compress(self, data: str | bytes) -> bytes: + if isinstance(data, bytes): + return zlib.compress(data, self._level) + return zlib.compress(data.encode("utf-8"), self._level) diff --git a/arangoasync/connection.py b/arangoasync/connection.py index 0fdd2bc..288c370 100644 --- a/arangoasync/connection.py +++ b/arangoasync/connection.py @@ -64,8 +64,10 @@ def prep_response(self, request: Request, resp: Response) -> Response: ServerConnectionError: If the response status code is not successful. """ resp.is_success = 200 <= resp.status_code < 300 + if resp.status_code in {401, 403}: + raise ServerConnectionError(resp, request, "Authentication failed.") if not resp.is_success: - raise ServerConnectionError(resp, request) + raise ServerConnectionError(resp, request, "Bad server response.") return resp async def process_request(self, request: Request) -> Response: @@ -110,10 +112,6 @@ async def ping(self) -> int: """ request = Request(method=Method.GET, endpoint="/_api/collection") resp = await self.send_request(request) - if resp.status_code in {401, 403}: - raise ServerConnectionError(resp, request, "Authentication failed.") - if not resp.is_success: - raise ServerConnectionError(resp, request, "Bad server response.") return resp.status_code @abstractmethod @@ -161,9 +159,9 @@ async def send_request(self, request: Request) -> Response: request.data ): request.data = self._compression.compress(request.data) - request.headers["content-encoding"] = self._compression.content_encoding() + request.headers["content-encoding"] = self._compression.content_encoding - accept_encoding: str | None = self._compression.accept_encoding() + accept_encoding: str | None = self._compression.accept_encoding if accept_encoding is not None: request.headers["accept-encoding"] = accept_encoding diff --git a/arangoasync/http.py b/arangoasync/http.py index e80dc91..7fba5c2 100644 --- a/arangoasync/http.py +++ b/arangoasync/http.py @@ -151,8 +151,8 @@ async def send_request( async with session.request( request.method.name, request.endpoint, - headers=request.headers, - params=request.params, + headers=request.normalized_headers(), + params=request.normalized_params(), data=request.data, auth=auth, ) as response: diff --git a/arangoasync/request.py b/arangoasync/request.py index 0c183d5..8890468 100644 --- a/arangoasync/request.py +++ b/arangoasync/request.py @@ -63,18 +63,14 @@ def __init__( ) -> None: self.method: Method = method self.endpoint: str = endpoint - self.headers: RequestHeaders = self._normalize_headers(headers) - self.params: Params = self._normalize_params(params) + self.headers: RequestHeaders = headers or dict() + self.params: Params = params or dict() self.data: Optional[bytes] = data self.auth: Optional[Auth] = auth - @staticmethod - def _normalize_headers(headers: Optional[RequestHeaders]) -> RequestHeaders: + def normalized_headers(self) -> RequestHeaders: """Normalize request headers. - Parameters: - headers (dict | None): Request headers. - Returns: dict: Normalized request headers. """ @@ -85,26 +81,22 @@ def _normalize_headers(headers: Optional[RequestHeaders]) -> RequestHeaders: "x-arango-driver": driver_header, } - if headers is not None: - for key, value in headers.items(): + if self.headers is not None: + for key, value in self.headers.items(): normalized_headers[key.lower()] = value return normalized_headers - @staticmethod - def _normalize_params(params: Optional[Params]) -> Params: + def normalized_params(self) -> Params: """Normalize URL parameters. - Parameters: - params (dict | None): URL parameters. - Returns: dict: Normalized URL parameters. """ normalized_params: Params = {} - if params is not None: - for key, value in params.items(): + if self.params is not None: + for key, value in self.params.items(): if isinstance(value, bool): value = int(value) normalized_params[key] = str(value) diff --git a/tests/test_compression.py b/tests/test_compression.py index 26e9baa..de4c5dd 100644 --- a/tests/test_compression.py +++ b/tests/test_compression.py @@ -16,5 +16,24 @@ def test_DefaultCompressionManager_compress(): data = "a" * 10 + "b" * 10 assert manager.needs_compression(data) assert len(manager.compress(data)) < len(data) - assert manager.content_encoding() == "deflate" - assert manager.accept_encoding() == "deflate" + assert manager.content_encoding == "deflate" + assert manager.accept_encoding == "deflate" + data = b"a" * 10 + b"b" * 10 + assert manager.needs_compression(data) + assert len(manager.compress(data)) < len(data) + + +def test_DefaultCompressionManager_properties(): + manager = DefaultCompressionManager( + threshold=1, level=9, accept=AcceptEncoding.DEFLATE + ) + assert manager.threshold == 1 + assert manager.level == 9 + assert manager.accept_encoding == "deflate" + assert manager.content_encoding == "deflate" + manager.threshold = 10 + assert manager.threshold == 10 + manager.level = 2 + assert manager.level == 2 + manager.accept_encoding = AcceptEncoding.GZIP + assert manager.accept_encoding == "gzip" diff --git a/tests/test_connection.py b/tests/test_connection.py index 40a525d..2c22855 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,10 +1,19 @@ +import zlib + import pytest from arangoasync.auth import Auth +from arangoasync.compression import AcceptEncoding, DefaultCompressionManager from arangoasync.connection import BasicConnection -from arangoasync.exceptions import ServerConnectionError +from arangoasync.exceptions import ( + ClientConnectionError, + ConnectionAbortedError, + ServerConnectionError, +) from arangoasync.http import AioHTTPClient +from arangoasync.request import Method, Request from arangoasync.resolver import DefaultHostResolver +from arangoasync.response import Response @pytest.mark.asyncio @@ -40,5 +49,81 @@ async def test_BasicConnection_ping_success( auth=Auth(username=root, password=password), ) + assert connection.db_name == sys_db_name status_code = await connection.ping() assert status_code == 200 + + +@pytest.mark.asyncio +async def test_BasicConnection_with_compression( + client_session, url, sys_db_name, root, password +): + client = AioHTTPClient() + session = client_session(client, url) + resolver = DefaultHostResolver(1) + compression = DefaultCompressionManager( + threshold=2, level=5, accept=AcceptEncoding.DEFLATE + ) + + connection = BasicConnection( + sessions=[session], + host_resolver=resolver, + http_client=client, + db_name=sys_db_name, + auth=Auth(username=root, password=password), + compression=compression, + ) + + data = b"a" * 100 + request = Request(method=Method.GET, endpoint="/_api/collection", data=data) + _ = await connection.send_request(request) + assert len(request.data) < len(data) + assert zlib.decompress(request.data) == data + assert request.headers["content-encoding"] == "deflate" + assert request.headers["accept-encoding"] == "deflate" + + +def test_BasicConnection_prep_response_bad_response(client_session, url, sys_db_name): + client = AioHTTPClient() + session = client_session(client, url) + resolver = DefaultHostResolver(1) + + connection = BasicConnection( + sessions=[session], + host_resolver=resolver, + http_client=client, + db_name=sys_db_name, + ) + + request = Request(method=Method.GET, endpoint="/_api/collection") + response = Response(Method.GET, url, {}, 0, "ERROR", b"") + + with pytest.raises(ServerConnectionError): + connection.prep_response(request, response) + + +@pytest.mark.asyncio +async def test_BasicConnection_process_request_connection_aborted( + monkeypatch, client_session, url, sys_db_name, root, password +): + client = AioHTTPClient() + session = client_session(client, url) + resolver = DefaultHostResolver(1, 1) + + request = Request(method=Method.GET, endpoint="/_api/collection") + + async def mock_send_request(*args, **kwargs): + raise ClientConnectionError("test") + + monkeypatch.setattr(client, "send_request", mock_send_request) + + connection = BasicConnection( + sessions=[session], + host_resolver=resolver, + http_client=client, + db_name=sys_db_name, + auth=Auth(username=root, password=password), + ) + + with pytest.raises(ConnectionAbortedError): + await connection.process_request(request) From 8ecc760c8c2108cdc64734de97664b14b12892f6 Mon Sep 17 00:00:00 2001 From: Alex Petenchea Date: Sun, 25 Aug 2024 20:08:08 +0300 Subject: [PATCH 2/2] Marking test async --- tests/test_connection.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index 2c22855..bf5409a 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -83,7 +83,10 @@ async def test_BasicConnection_with_compression( assert request.headers["accept-encoding"] == "deflate" -def test_BasicConnection_prep_response_bad_response(client_session, url, sys_db_name): +@pytest.mark.asyncio +async def test_BasicConnection_prep_response_bad_response( + client_session, url, sys_db_name +): client = AioHTTPClient() session = client_session(client, url) resolver = DefaultHostResolver(1)