diff --git a/arangoasync/aql.py b/arangoasync/aql.py index 021c054..0e00c57 100644 --- a/arangoasync/aql.py +++ b/arangoasync/aql.py @@ -1,4 +1,4 @@ -__all__ = ["AQL"] +__all__ = ["AQL", "AQLQueryCache"] from typing import Optional @@ -6,6 +6,10 @@ from arangoasync.cursor import Cursor from arangoasync.errno import HTTP_NOT_FOUND from arangoasync.exceptions import ( + AQLCacheClearError, + AQLCacheConfigureError, + AQLCacheEntriesError, + AQLCachePropertiesError, AQLQueryClearError, AQLQueryExecuteError, AQLQueryExplainError, @@ -23,6 +27,7 @@ from arangoasync.typings import ( Json, Jsons, + QueryCacheProperties, QueryExplainOptions, QueryProperties, QueryTrackingConfiguration, @@ -30,6 +35,188 @@ ) +class AQLQueryCache: + """AQL Query Cache API wrapper. + + Args: + executor: API executor. Required to execute the API requests. + """ + + def __init__(self, executor: ApiExecutor) -> None: + self._executor = executor + + @property + def name(self) -> str: + """Return the name of the current database.""" + return self._executor.db_name + + @property + def serializer(self) -> Serializer[Json]: + """Return the serializer.""" + return self._executor.serializer + + @property + def deserializer(self) -> Deserializer[Json, Jsons]: + """Return the deserializer.""" + return self._executor.deserializer + + def __repr__(self) -> str: + return f"" + + async def entries(self) -> Result[Jsons]: + """Return a list of all AQL query results cache entries. + + + Returns: + list: List of AQL query results cache entries. + + Raises: + AQLCacheEntriesError: If retrieval fails. + + References: + - `list-the-entries-of-the-aql-query-results-cache `__ + """ # noqa: E501 + request = Request(method=Method.GET, endpoint="/_api/query-cache/entries") + + def response_handler(resp: Response) -> Jsons: + if not resp.is_success: + raise AQLCacheEntriesError(resp, request) + return self.deserializer.loads_many(resp.raw_body) + + return await self._executor.execute(request, response_handler) + + async def plan_entries(self) -> Result[Jsons]: + """Return a list of all AQL query plan cache entries. + + Returns: + list: List of AQL query plan cache entries. + + Raises: + AQLCacheEntriesError: If retrieval fails. + + References: + - `list-the-entries-of-the-aql-query-plan-cache `__ + """ # noqa: E501 + request = Request(method=Method.GET, endpoint="/_api/query-plan-cache") + + def response_handler(resp: Response) -> Jsons: + if not resp.is_success: + raise AQLCacheEntriesError(resp, request) + return self.deserializer.loads_many(resp.raw_body) + + return await self._executor.execute(request, response_handler) + + async def clear(self) -> Result[None]: + """Clear the AQL query results cache. + + Raises: + AQLCacheClearError: If clearing the cache fails. + + References: + - `clear-the-aql-query-results-cache `__ + """ # noqa: E501 + request = Request(method=Method.DELETE, endpoint="/_api/query-cache") + + def response_handler(resp: Response) -> None: + if not resp.is_success: + raise AQLCacheClearError(resp, request) + + return await self._executor.execute(request, response_handler) + + async def clear_plan(self) -> Result[None]: + """Clear the AQL query plan cache. + + Raises: + AQLCacheClearError: If clearing the cache fails. + + References: + - `clear-the-aql-query-plan-cache `__ + """ # noqa: E501 + request = Request(method=Method.DELETE, endpoint="/_api/query-plan-cache") + + def response_handler(resp: Response) -> None: + if not resp.is_success: + raise AQLCacheClearError(resp, request) + + return await self._executor.execute(request, response_handler) + + async def properties(self) -> Result[QueryCacheProperties]: + """Return the current AQL query results cache configuration. + + Returns: + QueryCacheProperties: Current AQL query cache properties. + + Raises: + AQLCachePropertiesError: If retrieval fails. + + References: + - `get-the-aql-query-results-cache-configuration `__ + """ # noqa: E501 + request = Request(method=Method.GET, endpoint="/_api/query-cache/properties") + + def response_handler(resp: Response) -> QueryCacheProperties: + if not resp.is_success: + raise AQLCachePropertiesError(resp, request) + return QueryCacheProperties(self.deserializer.loads(resp.raw_body)) + + return await self._executor.execute(request, response_handler) + + async def configure( + self, + mode: Optional[str] = None, + max_results: Optional[int] = None, + max_results_size: Optional[int] = None, + max_entry_size: Optional[int] = None, + include_system: Optional[bool] = None, + ) -> Result[QueryCacheProperties]: + """Configure the AQL query results cache. + + Args: + mode (str | None): Cache mode. Allowed values are `"off"`, `"on"`, + and `"demand"`. + max_results (int | None): Max number of query results stored per + database-specific cache. + max_results_size (int | None): Max cumulative size of query results stored + per database-specific cache. + max_entry_size (int | None): Max entry size of each query result stored per + database-specific cache. + include_system (bool | None): Store results of queries in system collections. + + Returns: + QueryCacheProperties: Updated AQL query cache properties. + + Raises: + AQLCacheConfigureError: If setting the configuration fails. + + References: + - `set-the-aql-query-results-cache-configuration `__ + """ # noqa: E501 + data: Json = dict() + if mode is not None: + data["mode"] = mode + if max_results is not None: + data["maxResults"] = max_results + if max_results_size is not None: + data["maxResultsSize"] = max_results_size + if max_entry_size is not None: + data["maxEntrySize"] = max_entry_size + if include_system is not None: + data["includeSystem"] = include_system + + request = Request( + method=Method.PUT, + endpoint="/_api/query-cache/properties", + data=self.serializer.dumps(data), + ) + + def response_handler(resp: Response) -> QueryCacheProperties: + if not resp.is_success: + raise AQLCacheConfigureError(resp, request) + return QueryCacheProperties(self.deserializer.loads(resp.raw_body)) + + return await self._executor.execute(request, response_handler) + + class AQL: """AQL (ArangoDB Query Language) API wrapper. @@ -58,6 +245,11 @@ def deserializer(self) -> Deserializer[Json, Jsons]: """Return the deserializer.""" return self._executor.deserializer + @property + def cache(self) -> AQLQueryCache: + """Return the AQL Query Cache API wrapper.""" + return AQLQueryCache(self._executor) + def __repr__(self) -> str: return f"" diff --git a/arangoasync/database.py b/arangoasync/database.py index 3f91c56..12913db 100644 --- a/arangoasync/database.py +++ b/arangoasync/database.py @@ -27,6 +27,7 @@ PermissionResetError, PermissionUpdateError, ServerStatusError, + ServerVersionError, TransactionAbortError, TransactionCommitError, TransactionExecuteError, @@ -1189,6 +1190,32 @@ def response_handler(resp: Response) -> Any: return await self._executor.execute(request, response_handler) + async def version(self, details: bool = False) -> Result[Json]: + """Return the server version information. + + Args: + details (bool): If `True`, return detailed version information. + + Returns: + dict: Server version information. + + Raises: + ServerVersionError: If the operation fails on the server side. + + References: + - `get-the-server-version `__ + """ # noqa: E501 + request = Request( + method=Method.GET, endpoint="/_api/version", params={"details": details} + ) + + def response_handler(resp: Response) -> Json: + if not resp.is_success: + raise ServerVersionError(resp, request) + return self.deserializer.loads(resp.raw_body) + + return await self._executor.execute(request, response_handler) + class StandardDatabase(Database): """Standard database API wrapper. diff --git a/arangoasync/exceptions.py b/arangoasync/exceptions.py index 6cf31c5..ff3e0d1 100644 --- a/arangoasync/exceptions.py +++ b/arangoasync/exceptions.py @@ -71,6 +71,22 @@ def __init__( self.http_headers = resp.headers +class AQLCacheClearError(ArangoServerError): + """Failed to clear the query cache.""" + + +class AQLCacheConfigureError(ArangoServerError): + """Failed to configure query cache properties.""" + + +class AQLCacheEntriesError(ArangoServerError): + """Failed to retrieve AQL cache entries.""" + + +class AQLCachePropertiesError(ArangoServerError): + """Failed to retrieve query cache properties.""" + + class AQLQueryClearError(ArangoServerError): """Failed to clear slow AQL queries.""" @@ -251,6 +267,10 @@ class ServerStatusError(ArangoServerError): """Failed to retrieve server status.""" +class ServerVersionError(ArangoServerError): + """Failed to retrieve server version.""" + + class TransactionAbortError(ArangoServerError): """Failed to abort transaction.""" diff --git a/arangoasync/typings.py b/arangoasync/typings.py index 496e5ca..a24367c 100644 --- a/arangoasync/typings.py +++ b/arangoasync/typings.py @@ -1096,6 +1096,9 @@ class QueryProperties(JsonWrapper): store intermediate and final results temporarily on disk if the number of rows produced by the query exceeds the specified value. stream (bool | None): Can be enabled to execute the query lazily. + use_plan_cache (bool | None): Set this option to `True` to utilize + a cached query plan or add the execution plan of this query to the + cache if it’s not in the cache yet. Example: .. code-block:: json @@ -1136,6 +1139,7 @@ def __init__( spill_over_threshold_memory_usage: Optional[int] = None, spill_over_threshold_num_rows: Optional[int] = None, stream: Optional[bool] = None, + use_plan_cache: Optional[bool] = None, ) -> None: data: Json = dict() if allow_dirty_reads is not None: @@ -1178,6 +1182,8 @@ def __init__( data["spillOverThresholdNumRows"] = spill_over_threshold_num_rows if stream is not None: data["stream"] = stream + if use_plan_cache is not None: + data["usePlanCache"] = use_plan_cache super().__init__(data) @property @@ -1260,6 +1266,10 @@ def spill_over_threshold_num_rows(self) -> Optional[int]: def stream(self) -> Optional[bool]: return self._data.get("stream") + @property + def use_plan_cache(self) -> Optional[bool]: + return self._data.get("usePlanCache") + class QueryExecutionPlan(JsonWrapper): """The execution plan of an AQL query. @@ -1598,3 +1608,46 @@ def max_plans(self) -> Optional[int]: @property def optimizer(self) -> Optional[Json]: return self._data.get("optimizer") + + +class QueryCacheProperties(JsonWrapper): + """AQL Cache Configuration. + + Example: + .. code-block:: json + + { + "mode" : "demand", + "maxResults" : 128, + "maxResultsSize" : 268435456, + "maxEntrySize" : 16777216, + "includeSystem" : false + } + + References: + - `get-the-aql-query-results-cache-configuration `__ + - `set-the-aql-query-results-cache-configuration `__ + """ # noqa: E501 + + def __init__(self, data: Json) -> None: + super().__init__(data) + + @property + def mode(self) -> str: + return cast(str, self._data.get("mode", "")) + + @property + def max_results(self) -> int: + return cast(int, self._data.get("maxResults", 0)) + + @property + def max_results_size(self) -> int: + return cast(int, self._data.get("maxResultsSize", 0)) + + @property + def max_entry_size(self) -> int: + return cast(int, self._data.get("maxEntrySize", 0)) + + @property + def include_system(self) -> bool: + return cast(bool, self._data.get("includeSystem", False)) diff --git a/tests/conftest.py b/tests/conftest.py index 65846e7..e91a591 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,7 @@ import pytest import pytest_asyncio +from packaging import version from arangoasync.auth import Auth, JwtToken from arangoasync.client import ArangoClient @@ -21,6 +22,7 @@ class GlobalData: username: str = generate_username() cluster: bool = False enterprise: bool = False + db_version: version = version.parse("0.0.0") global_data = GlobalData() @@ -63,6 +65,19 @@ def pytest_configure(config): global_data.cluster = config.getoption("cluster") global_data.enterprise = config.getoption("enterprise") + async def get_db_version(): + async with ArangoClient(hosts=global_data.url) as client: + sys_db = await client.db( + global_data.sys_db_name, + auth_method="basic", + auth=Auth(global_data.root, global_data.password), + verify=False, + ) + db_version = (await sys_db.version())["version"] + global_data.db_version = version.parse(db_version.split("-")[0]) + + asyncio.run(get_db_version()) + @pytest.fixture def url(): @@ -213,6 +228,11 @@ def bad_col(db): return db.collection(col_name) +@pytest.fixture +def db_version(): + return global_data.db_version + + @pytest_asyncio.fixture(scope="session", autouse=True) async def teardown(): yield diff --git a/tests/test_aql.py b/tests/test_aql.py index 74db7b0..9176974 100644 --- a/tests/test_aql.py +++ b/tests/test_aql.py @@ -2,9 +2,14 @@ import time import pytest +from packaging import version -from arangoasync.errno import QUERY_PARSE +from arangoasync.errno import FORBIDDEN, QUERY_PARSE from arangoasync.exceptions import ( + AQLCacheClearError, + AQLCacheConfigureError, + AQLCacheEntriesError, + AQLCachePropertiesError, AQLQueryClearError, AQLQueryExecuteError, AQLQueryExplainError, @@ -190,3 +195,84 @@ async def test_query_rules(db, bad_db): with pytest.raises(AQLQueryRulesGetError): _ = await bad_db.aql.query_rules() + + +@pytest.mark.asyncio +async def test_cache_results_management(db, bad_db, doc_col, docs, cluster): + if cluster: + pytest.skip("Cluster mode does not support query rest cache management") + + aql = db.aql + cache = aql.cache + + # Sanity check, just see if the response is OK. + _ = await cache.properties() + with pytest.raises(AQLCachePropertiesError) as err: + _ = await bad_db.aql.cache.properties() + assert err.value.error_code == FORBIDDEN + + # Turn on caching + result = await cache.configure(mode="on") + assert result.mode == "on" + result = await cache.properties() + assert result.mode == "on" + with pytest.raises(AQLCacheConfigureError) as err: + _ = await bad_db.aql.cache.configure(mode="on") + assert err.value.error_code == FORBIDDEN + + # Run a simple query to use the cache + await doc_col.insert(docs[0]) + _ = await aql.execute( + query="FOR doc IN @@collection RETURN doc", + bind_vars={"@collection": doc_col.name}, + ) + + # Check the entries + entries = await cache.entries() + assert isinstance(entries, list) + assert len(entries) > 0 + + with pytest.raises(AQLCacheEntriesError) as err: + _ = await bad_db.aql.cache.entries() + assert err.value.error_code == FORBIDDEN + + # Clear the cache + await cache.clear() + entries = await cache.entries() + assert len(entries) == 0 + with pytest.raises(AQLCacheClearError) as err: + await bad_db.aql.cache.clear() + assert err.value.error_code == FORBIDDEN + + +@pytest.mark.asyncio +async def test_cache_plan_management(db, bad_db, doc_col, docs, db_version): + if db_version < version.parse("3.12.4"): + pytest.skip("Query plan cache is supported in ArangoDB 3.12.4+") + + aql = db.aql + cache = aql.cache + + # Run a simple query to use the cache + await doc_col.insert(docs[0]) + _ = await aql.execute( + query="FOR doc IN @@collection RETURN doc", + bind_vars={"@collection": doc_col.name}, + options={"usePlanCache": True}, + ) + + # Check the entries + entries = await cache.plan_entries() + assert isinstance(entries, list) + assert len(entries) > 0 + with pytest.raises(AQLCacheEntriesError) as err: + _ = await bad_db.aql.cache.plan_entries() + assert err.value.error_code == FORBIDDEN + + # Clear the cache + await cache.clear_plan() + entries = await cache.plan_entries() + assert len(entries) == 0 + with pytest.raises(AQLCacheClearError) as err: + await bad_db.aql.cache.clear_plan() + assert err.value.error_code == FORBIDDEN diff --git a/tests/test_database.py b/tests/test_database.py index fc8b2bc..eb7daa3 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -14,6 +14,7 @@ JWTSecretListError, JWTSecretReloadError, ServerStatusError, + ServerVersionError, ) from arangoasync.typings import CollectionType, KeyOptions, UserInfo from tests.helpers import generate_col_name, generate_db_name, generate_username @@ -48,6 +49,12 @@ async def test_database_misc_methods(sys_db, db, bad_db, cluster): with pytest.raises(JWTSecretReloadError): await bad_db.reload_jwt_secrets() + # Version + version = await sys_db.version() + assert version["version"].startswith("3.") + with pytest.raises(ServerVersionError): + await bad_db.version() + @pytest.mark.asyncio async def test_create_drop_database( diff --git a/tests/test_typings.py b/tests/test_typings.py index 218f421..9d8e2d5 100644 --- a/tests/test_typings.py +++ b/tests/test_typings.py @@ -6,6 +6,7 @@ CollectionType, JsonWrapper, KeyOptions, + QueryCacheProperties, QueryExecutionExtra, QueryExecutionPlan, QueryExecutionProfile, @@ -156,6 +157,7 @@ def test_QueryProperties(): spill_over_threshold_memory_usage=10485760, spill_over_threshold_num_rows=100000, stream=True, + use_plan_cache=True, ) assert properties.allow_dirty_reads is True assert properties.allow_retry is False @@ -177,6 +179,7 @@ def test_QueryProperties(): assert properties.spill_over_threshold_memory_usage == 10485760 assert properties.spill_over_threshold_num_rows == 100000 assert properties.stream is True + assert properties.use_plan_cache is True def test_QueryExecutionPlan(): @@ -313,3 +316,17 @@ def test_QueryExplainOptions(): assert options.all_plans is True assert options.max_plans == 5 assert options.optimizer == {"rules": ["-all", "+use-index-range"]} + + +def test_QueryCacheProperties(): + data = { + "mode": "demand", + "maxResults": 128, + "maxEntrySize": 1024, + "includeSystem": False, + } + cache_properties = QueryCacheProperties(data) + assert cache_properties._data["mode"] == "demand" + assert cache_properties._data["maxResults"] == 128 + assert cache_properties._data["maxEntrySize"] == 1024 + assert cache_properties._data["includeSystem"] is False