diff --git a/arangoasync/client.py b/arangoasync/client.py index ab54426..83b2f67 100644 --- a/arangoasync/client.py +++ b/arangoasync/client.py @@ -43,7 +43,7 @@ class ArangoClient: responses. Enable it by passing an instance of :class:`DefaultCompressionManager ` - or a subclass of :class:`CompressionManager + or a custom subclass of :class:`CompressionManager `. Raises: @@ -143,8 +143,8 @@ async def db( auth (Auth | None): Login information. token (JwtToken | None): JWT token. verify (bool): Verify the connection by sending a test request. - compression (CompressionManager | None): Supersedes the client-level - compression settings. + compression (CompressionManager | None): If set, supersedes the + client-level compression settings. Returns: Database: Database API wrapper. diff --git a/arangoasync/compression.py b/arangoasync/compression.py index adc3957..f025a3f 100644 --- a/arangoasync/compression.py +++ b/arangoasync/compression.py @@ -86,17 +86,16 @@ class DefaultCompressionManager(CompressionManager): Args: threshold (int): Will compress requests to the server if the size of the request body (in bytes) is at least the value of this option. - Setting it to -1 will disable request compression (default). + Setting it to -1 will disable request compression. level (int): Compression level. Defaults to 6. - accept (str | None): Accepted encoding. By default, there is - no compression of responses. + accept (str | None): Accepted encoding. Can be disabled by setting it to `None`. """ def __init__( self, - threshold: int = -1, + threshold: int = 1024, level: int = 6, - accept: Optional[AcceptEncoding] = None, + accept: Optional[AcceptEncoding] = AcceptEncoding.DEFLATE, ) -> None: self._threshold = threshold self._level = level @@ -132,7 +131,7 @@ def content_encoding(self) -> str: return self._content_encoding def needs_compression(self, data: str | bytes) -> bool: - return self._threshold != -1 and len(data) >= self._threshold + return len(data) >= self._threshold def compress(self, data: str | bytes) -> bytes: if isinstance(data, bytes): diff --git a/arangoasync/connection.py b/arangoasync/connection.py index 0d342de..68a021f 100644 --- a/arangoasync/connection.py +++ b/arangoasync/connection.py @@ -14,7 +14,7 @@ from arangoasync import errno, logger from arangoasync.auth import Auth, JwtToken -from arangoasync.compression import CompressionManager, DefaultCompressionManager +from arangoasync.compression import CompressionManager from arangoasync.exceptions import ( AuthHeaderError, ClientConnectionAbortedError, @@ -52,7 +52,7 @@ def __init__( self._host_resolver = host_resolver self._http_client = http_client self._db_name = db_name - self._compression = compression or DefaultCompressionManager() + self._compression = compression @property def db_name(self) -> str: @@ -100,6 +100,38 @@ def prep_response(request: Request, resp: Response) -> Response: resp.error_message = body.get("errorMessage") return resp + def compress_request(self, request: Request) -> bool: + """Compress request if needed. + + Additionally, the server may be instructed to compress the response. + The decision to compress the request is based on the compression strategy + passed during the connection initialization. + The request headers and may be modified as a result of this operation. + + Args: + request (Request): Request to be compressed. + + Returns: + bool: True if compression settings were applied. + """ + if self._compression is None: + return False + + result: bool = False + if request.data is not None and self._compression.needs_compression( + request.data + ): + request.data = self._compression.compress(request.data) + request.headers["content-encoding"] = self._compression.content_encoding + result = True + + accept_encoding: str | None = self._compression.accept_encoding + if accept_encoding is not None: + request.headers["accept-encoding"] = accept_encoding + result = True + + return result + async def process_request(self, request: Request) -> Response: """Process request, potentially trying multiple hosts. @@ -198,15 +230,7 @@ async def send_request(self, request: Request) -> Response: ArangoClientError: If an error occurred from the client side. ArangoServerError: If an error occurred from the server side. """ - if request.data is not None and self._compression.needs_compression( - request.data - ): - request.data = self._compression.compress(request.data) - request.headers["content-encoding"] = self._compression.content_encoding - - accept_encoding: str | None = self._compression.accept_encoding - if accept_encoding is not None: - request.headers["accept-encoding"] = accept_encoding + self.compress_request(request) if self._auth: request.auth = self._auth @@ -335,6 +359,7 @@ async def send_request(self, request: Request) -> Response: raise AuthHeaderError("Failed to generate authorization header.") request.headers["authorization"] = self._auth_header + self.compress_request(request) resp = await self.process_request(request) if ( @@ -416,6 +441,7 @@ async def send_request(self, request: Request) -> Response: if self._auth_header is None: raise AuthHeaderError("Failed to generate authorization header.") request.headers["authorization"] = self._auth_header + self.compress_request(request) resp = await self.process_request(request) self.raise_for_status(request, resp)