From 508cc88da20ef67d726e0afd6f60943266c2af14 Mon Sep 17 00:00:00 2001 From: Alex Petenchea Date: Tue, 29 Oct 2024 20:29:28 +0200 Subject: [PATCH 1/3] Improving REST API coverage --- arangoasync/collection.py | 45 ++++- arangoasync/database.py | 349 +++++++++++++++++++++++++++++++++++- arangoasync/exceptions.py | 40 +++++ arangoasync/executor.py | 4 + arangoasync/typings.py | 366 ++++++++++++++++++++++++++++++++++++-- tests/conftest.py | 6 + tests/test_client.py | 6 +- tests/test_collection.py | 20 +++ tests/test_database.py | 38 +++- tests/test_user.py | 120 ++++++++++++- 10 files changed, 963 insertions(+), 31 deletions(-) create mode 100644 tests/test_collection.py diff --git a/arangoasync/collection.py b/arangoasync/collection.py index 92ae0d4..ff5de47 100644 --- a/arangoasync/collection.py +++ b/arangoasync/collection.py @@ -9,6 +9,7 @@ HTTP_PRECONDITION_FAILED, ) from arangoasync.exceptions import ( + CollectionPropertiesError, DocumentGetError, DocumentInsertError, DocumentParseError, @@ -18,7 +19,7 @@ from arangoasync.request import Method, Request from arangoasync.response import Response from arangoasync.serialization import Deserializer, Serializer -from arangoasync.typings import Json, Params, Result +from arangoasync.typings import CollectionProperties, Json, Params, Result T = TypeVar("T") U = TypeVar("U") @@ -48,9 +49,6 @@ def __init__( self._doc_deserializer = doc_deserializer self._id_prefix = f"{self._name}/" - def __repr__(self) -> str: - return f"" - def _validate_id(self, doc_id: str) -> str: """Check the collection name in the document ID. @@ -148,6 +146,15 @@ def name(self) -> str: """ return self._name + @property + def db_name(self) -> str: + """Return the name of the current database. + + Returns: + str: Database name. + """ + return self._executor.db_name + class StandardCollection(Collection[T, U, V]): """Standard collection API wrapper. @@ -168,6 +175,33 @@ def __init__( ) -> None: super().__init__(executor, name, doc_serializer, doc_deserializer) + def __repr__(self) -> str: + return f"" + + async def properties(self) -> Result[CollectionProperties]: + """Return the full properties of the current collection. + + Returns: + CollectionProperties: Properties. + + Raises: + CollectionPropertiesError: If retrieval fails. + + References: + - `get-the-properties-of-a-collection `__ + """ # noqa: E501 + request = Request( + method=Method.GET, + endpoint=f"/_api/collection/{self.name}/properties", + ) + + def response_handler(resp: Response) -> CollectionProperties: + if not resp.is_success: + raise CollectionPropertiesError(resp, request) + return CollectionProperties(self._executor.deserialize(resp.raw_body)) + + return await self._executor.execute(request, response_handler) + async def get( self, document: str | Json, @@ -269,6 +303,9 @@ async def insert( bool | dict: Document metadata (e.g. document id, key, revision) or `True` if **silent** is set to `True`. + Raises: + DocumentInsertError: If insertion fails. + References: - `create-a-document `__ """ # noqa: E501 diff --git a/arangoasync/database.py b/arangoasync/database.py index f420244..f5cb8e4 100644 --- a/arangoasync/database.py +++ b/arangoasync/database.py @@ -16,11 +16,20 @@ DatabaseCreateError, DatabaseDeleteError, DatabaseListError, + DatabasePropertiesError, + JWTSecretListError, + JWTSecretReloadError, + PermissionGetError, + PermissionListError, + PermissionResetError, + PermissionUpdateError, ServerStatusError, UserCreateError, UserDeleteError, UserGetError, UserListError, + UserReplaceError, + UserUpdateError, ) from arangoasync.executor import ApiExecutor, DefaultApiExecutor from arangoasync.request import Method, Request @@ -29,6 +38,7 @@ from arangoasync.typings import ( CollectionInfo, CollectionType, + DatabaseProperties, Json, Jsons, KeyOptions, @@ -74,6 +84,29 @@ def deserializer(self) -> Deserializer[Json, Jsons]: """Return the deserializer.""" return self._executor.deserializer + async def properties(self) -> Result[DatabaseProperties]: + """Return database properties. + + Returns: + DatabaseProperties: Properties of the current database. + + Raises: + DatabasePropertiesError: If retrieval fails. + + References: + - `get-information-about-the-current-database `__ + """ # noqa: E501 + request = Request(method=Method.GET, endpoint="/_api/database/current") + + def response_handler(resp: Response) -> DatabaseProperties: + if not resp.is_success: + raise DatabasePropertiesError(resp, request) + return DatabaseProperties( + self.deserializer.loads(resp.raw_body), strip_result=True + ) + + return await self._executor.execute(request, response_handler) + async def status(self) -> Result[ServerStatusInformation]: """Query the server status. @@ -123,6 +156,31 @@ def response_handler(resp: Response) -> List[str]: return await self._executor.execute(request, response_handler) + async def databases_accessible_to_user(self) -> Result[List[str]]: + """Return the names of all databases accessible to the current user. + + Note: + This method can only be executed in the **_system** database. + + Returns: + list: Database names. + + Raises: + DatabaseListError: If retrieval fails. + + References: + - `list-the-accessible-databases `__ + """ # noqa: E501 + request = Request(method=Method.GET, endpoint="/_api/database/user") + + def response_handler(resp: Response) -> List[str]: + if resp.is_success: + body = self.deserializer.loads(resp.raw_body) + return cast(List[str], body["result"]) + raise DatabaseListError(resp, request) + + return await self._executor.execute(request, response_handler) + async def has_database(self, name: str) -> Result[bool]: """Check if a database exists. @@ -643,10 +701,7 @@ def response_handler(resp: Response) -> Sequence[UserInfo]: return await self._executor.execute(request, response_handler) - async def create_user( - self, - user: UserInfo, - ) -> Result[UserInfo]: + async def create_user(self, user: UserInfo | Json) -> Result[UserInfo]: """Create a new user. Args: @@ -673,7 +728,7 @@ async def create_user( if not user.user: raise ValueError("Username is required.") - data: Json = user.to_dict() + data: Json = user.format(UserInfo.user_management_formatter) request = Request( method=Method.POST, endpoint="/_api/user", @@ -692,6 +747,86 @@ def response_handler(resp: Response) -> UserInfo: return await self._executor.execute(request, response_handler) + async def replace_user(self, user: UserInfo | Json) -> Result[UserInfo]: + """Replace the data of an existing user. + + Args: + user (UserInfo | dict): New user information. + + Returns: + UserInfo: New user details. + + Raises: + ValueError: If the username is missing. + UserReplaceError: If the operation fails. + + References: + - `replace-a-user `__ + """ # noqa: E501 + if isinstance(user, dict): + user = UserInfo(**user) + if not user.user: + raise ValueError("Username is required.") + + data: Json = user.format(UserInfo.user_management_formatter) + request = Request( + method=Method.PUT, + endpoint=f"/_api/user/{user.user}", + data=self.serializer.dumps(data), + ) + + def response_handler(resp: Response) -> UserInfo: + if not resp.is_success: + raise UserReplaceError(resp, request) + body = self.deserializer.loads(resp.raw_body) + return UserInfo( + user=body["user"], + active=cast(bool, body.get("active")), + extra=body.get("extra"), + ) + + return await self._executor.execute(request, response_handler) + + async def update_user(self, user: UserInfo | Json) -> Result[UserInfo]: + """Partially modifies the data of an existing user. + + Args: + user (UserInfo | dict): User information. + + Returns: + UserInfo: Updated user details. + + Raises: + ValueError: If the username is missing. + UserUpdateError: If the operation fails. + + References: + - `update-a-user `__ + """ # noqa: E501 + if isinstance(user, dict): + user = UserInfo(**user) + if not user.user: + raise ValueError("Username is required.") + + data: Json = user.format(UserInfo.user_management_formatter) + request = Request( + method=Method.PATCH, + endpoint=f"/_api/user/{user.user}", + data=self.serializer.dumps(data), + ) + + def response_handler(resp: Response) -> UserInfo: + if not resp.is_success: + raise UserUpdateError(resp, request) + body = self.deserializer.loads(resp.raw_body) + return UserInfo( + user=body["user"], + active=cast(bool, body.get("active")), + extra=body.get("extra"), + ) + + return await self._executor.execute(request, response_handler) + async def delete_user( self, username: str, @@ -724,6 +859,210 @@ def response_handler(resp: Response) -> bool: return await self._executor.execute(request, response_handler) + async def permissions(self, username: str, full: bool = True) -> Result[Json]: + """Return user permissions for all databases and collections. + + Args: + username (str): Username. + full (bool): If `True`, the result will contain the permissions for the + databases as well as the permissions for the collections. + + Returns: + dict: User permissions for all databases and (optionally) collections. + + Raises: + PermissionListError: If the operation fails. + + References: + - `list-a-users-accessible-databases `__ + """ # noqa: 501 + request = Request( + method=Method.GET, + endpoint=f"/_api/user/{username}/database", + params={"full": full}, + ) + + def response_handler(resp: Response) -> Json: + if resp.is_success: + result: Json = self.deserializer.loads(resp.raw_body)["result"] + return result + raise PermissionListError(resp, request) + + return await self._executor.execute(request, response_handler) + + async def permission( + self, + username: str, + database: str, + collection: Optional[str] = None, + ) -> Result[str]: + """Return user permission for a specific database or collection. + + Args: + username (str): Username. + database (str): Database name. + collection (str | None): Collection name. + + Returns: + str: User access level. + + Raises: + PermissionGetError: If the operation fails. + + References: + - `get-a-users-database-access-level `__ + - `get-a-users-collection-access-level `__ + """ # noqa: 501 + endpoint = f"/_api/user/{username}/database/{database}" + if collection is not None: + endpoint += f"/{collection}" + request = Request(method=Method.GET, endpoint=endpoint) + + def response_handler(resp: Response) -> str: + if resp.is_success: + return cast(str, self.deserializer.loads(resp.raw_body)["result"]) + raise PermissionGetError(resp, request) + + return await self._executor.execute(request, response_handler) + + async def update_permission( + self, + username: str, + permission: str, + database: str, + collection: Optional[str] = None, + ignore_failure: bool = False, + ) -> Result[bool]: + """Update user permissions for a specific database or collection. + + Args: + username (str): Username. + permission (str): Allowed values are "rw" (administrate), + "ro" (access) and "none" (no access). + database (str): Database to set the access level for. + collection (str | None): Collection to set the access level for. + ignore_failure (bool): Do not raise an exception on failure. + + Returns: + bool: `True` if the operation was successful. + + Raises: + PermissionUpdateError: If the operation fails and `ignore_failure` + is `False`. + + References: + - `set-a-users-database-access-level `__ + - `set-a-users-collection-access-level `__ + """ # noqa: E501 + endpoint = f"/_api/user/{username}/database/{database}" + if collection is not None: + endpoint += f"/{collection}" + + request = Request( + method=Method.PUT, + endpoint=endpoint, + data=self.serializer.dumps({"grant": permission}), + ) + + def response_handler(resp: Response) -> bool: + nonlocal ignore_failure + if resp.is_success: + return True + if ignore_failure: + return False + raise PermissionUpdateError(resp, request) + + return await self._executor.execute(request, response_handler) + + async def reset_permission( + self, + username: str, + database: str, + collection: Optional[str] = None, + ignore_failure: bool = False, + ) -> Result[bool]: + """Reset user permission for a specific database or collection. + + Args: + username (str): Username. + database (str): Database to reset the access level for. + collection (str | None): Collection to reset the access level for. + ignore_failure (bool): Do not raise an exception on failure. + + Returns: + bool: `True` if the operation was successful. + + Raises: + PermissionResetError: If the operation fails and `ignore_failure` + is `False`. + + References: + - `clear-a-users-database-access-level `__ + - `clear-a-users-collection-access-level `__ + """ # noqa: E501 + endpoint = f"/_api/user/{username}/database/{database}" + if collection is not None: + endpoint += f"/{collection}" + + request = Request( + method=Method.DELETE, + endpoint=endpoint, + ) + + def response_handler(resp: Response) -> bool: + nonlocal ignore_failure + if resp.is_success: + return True + if ignore_failure: + return False + raise PermissionResetError(resp, request) + + return await self._executor.execute(request, response_handler) + + async def jwt_secrets(self) -> Result[Json]: + """Return information on currently loaded JWT secrets. + + Returns: + dict: JWT secrets. + + Raises: + JWTSecretListError: If the operation fails. + + References: + - `get-information-about-the-loaded-jwt-secrets `__ + """ # noqa: 501 + request = Request(method=Method.GET, endpoint="/_admin/server/jwt") + + def response_handler(resp: Response) -> Json: + if not resp.is_success: + raise JWTSecretListError(resp, request) + result: Json = self.deserializer.loads(resp.raw_body) + return result + + return await self._executor.execute(request, response_handler) + + async def reload_jwt_secrets(self) -> Result[Json]: + """Hot_reload JWT secrets from disk. + + Returns: + dict: Information on reloaded JWT secrets. + + Raises: + JWTSecretReloadError: If the operation fails. + + References: + - `hot-reload-the-jwt-secrets-from-disk `__ + """ # noqa: 501 + request = Request(method=Method.POST, endpoint="/_admin/server/jwt") + + def response_handler(resp: Response) -> Json: + if not resp.is_success: + raise JWTSecretReloadError(resp, request) + result: Json = self.deserializer.loads(resp.raw_body) + return result + + return await self._executor.execute(request, response_handler) + class StandardDatabase(Database): """Standard database API wrapper.""" diff --git a/arangoasync/exceptions.py b/arangoasync/exceptions.py index 7679d3a..9a5eb1b 100644 --- a/arangoasync/exceptions.py +++ b/arangoasync/exceptions.py @@ -87,6 +87,10 @@ class CollectionListError(ArangoServerError): """Failed to retrieve collections.""" +class CollectionPropertiesError(ArangoServerError): + """Failed to retrieve collection properties.""" + + class ClientConnectionAbortedError(ArangoClientError): """The connection was aborted.""" @@ -107,6 +111,10 @@ class DatabaseListError(ArangoServerError): """Failed to retrieve databases.""" +class DatabasePropertiesError(ArangoServerError): + """Failed to retrieve database properties.""" + + class DeserializationError(ArangoClientError): """Failed to deserialize the server response.""" @@ -131,6 +139,30 @@ class JWTRefreshError(ArangoClientError): """Failed to refresh the JWT token.""" +class JWTSecretListError(ArangoServerError): + """Failed to retrieve information on currently loaded JWT secrets.""" + + +class JWTSecretReloadError(ArangoServerError): + """Failed to reload JWT secrets.""" + + +class PermissionGetError(ArangoServerError): + """Failed to retrieve user permission.""" + + +class PermissionListError(ArangoServerError): + """Failed to list user permissions.""" + + +class PermissionResetError(ArangoServerError): + """Failed to reset user permission.""" + + +class PermissionUpdateError(ArangoServerError): + """Failed to update user permission.""" + + class SerializationError(ArangoClientError): """Failed to serialize the request.""" @@ -157,3 +189,11 @@ class UserGetError(ArangoServerError): class UserListError(ArangoServerError): """Failed to retrieve users.""" + + +class UserReplaceError(ArangoServerError): + """Failed to replace user.""" + + +class UserUpdateError(ArangoServerError): + """Failed to update user.""" diff --git a/arangoasync/executor.py b/arangoasync/executor.py index 48fafe5..175096e 100644 --- a/arangoasync/executor.py +++ b/arangoasync/executor.py @@ -29,6 +29,10 @@ def connection(self) -> Connection: def context(self) -> str: return "default" + @property + def db_name(self) -> str: + return self._conn.db_name + @property def serializer(self) -> Serializer[Json]: return self._conn.serializer diff --git a/arangoasync/typings.py b/arangoasync/typings.py index e9a1317..d0065bf 100644 --- a/arangoasync/typings.py +++ b/arangoasync/typings.py @@ -134,6 +134,9 @@ class JsonWrapper: def __init__(self, data: Json) -> None: self._data = data + for excluded in ("code", "error"): + if excluded in self._data: + self._data.pop(excluded) def __getitem__(self, key: str) -> Any: return self._data[key] @@ -175,11 +178,19 @@ def to_dict(self) -> Json: return self._data def format(self, formatter: Optional[Formatter] = None) -> Json: - """Apply a formatter to the data. Returns the unmodified data by default.""" + """Apply a formatter to the data. + + Returns the unmodified data by default. Should not modify the object in-place. + """ if formatter is not None: return formatter(self._data) return self._data + @staticmethod + def _strip_result(data: Json) -> Json: + """Keep only the `result` key from a dict. Useful when parsing responses.""" + return data["result"] # type: ignore[no-any-return] + class KeyOptions(JsonWrapper): """Additional options for key generation, used on collections. @@ -255,6 +266,31 @@ def validate(self) -> None: '"offset" value is only allowed for "autoincrement" ' "key generator" ) + @staticmethod + def compatibility_formatter(data: Json) -> Json: + """python-arango compatibility formatter.""" + result: Json = {} + if "type" in data: + result["key_generator"] = data["type"] + if "increment" in data: + result["key_increment"] = data["increment"] + if "offset" in data: + result["key_offset"] = data["offset"] + if "allowUserKeys" in data: + result["user_keys"] = data["allowUserKeys"] + if "lastValue" in data: + result["key_last_value"] = data["lastValue"] + return result + + def format(self, formatter: Optional[Formatter] = None) -> Json: + """Apply a formatter to the data. + + By default, the python-arango compatibility formatter is applied. + """ + if formatter is not None: + return super().format(formatter) + return self.compatibility_formatter(self._data) + class CollectionInfo(JsonWrapper): """Collection information. @@ -303,6 +339,17 @@ def col_type(self) -> CollectionType: """Return the type of the collection.""" return CollectionType.from_int(self._data["type"]) + @staticmethod + def compatibility_formatter(data: Json) -> Json: + """python-arango compatibility formatter.""" + return { + "id": data["id"], + "name": data["name"], + "system": data["isSystem"], + "type": str(CollectionType.from_int(data["type"])), + "status": str(CollectionStatus.from_int(data["status"])), + } + def format(self, formatter: Optional[Formatter] = None) -> Json: """Apply a formatter to the data. @@ -310,13 +357,7 @@ def format(self, formatter: Optional[Formatter] = None) -> Json: """ if formatter is not None: return super().format(formatter) - return { - "id": self._data["id"], - "name": self.name, - "system": self.is_system, - "type": str(self.col_type), - "status": str(self.status), - } + return self.compatibility_formatter(self._data) class UserInfo(JsonWrapper): @@ -360,7 +401,7 @@ def __init__( @property def user(self) -> str: - return self._data.get("user") # type: ignore[return-value] + return self._data["user"] # type: ignore[no-any-return] @property def password(self) -> Optional[str]: @@ -368,20 +409,29 @@ def password(self) -> Optional[str]: @property def active(self) -> bool: - return self._data.get("active") # type: ignore[return-value] + return self._data["active"] # type: ignore[no-any-return] @property def extra(self) -> Optional[Json]: return self._data.get("extra") - def to_dict(self) -> Json: - """Return the dictionary.""" - return dict( - user=self.user, - password=self.password, - active=self.active, - extra=self.extra, - ) + @staticmethod + def user_management_formatter(data: Json) -> Json: + """Request formatter.""" + result: Json = dict(user=data["user"]) + if "password" in data: + result["passwd"] = data["password"] + if "active" in data: + result["active"] = data["active"] + if "extra" in data: + result["extra"] = data["extra"] + return result + + def format(self, formatter: Optional[Formatter] = None) -> Json: + """Apply a formatter to the data.""" + if formatter is not None: + return super().format(formatter) + return self._data class ServerStatusInformation(JsonWrapper): @@ -485,3 +535,283 @@ def coordinator(self) -> Optional[Json]: @property def agency(self) -> Optional[Json]: return self._data.get("agency") + + +class DatabaseProperties(JsonWrapper): + """Properties of the database. + + References: + - `get-information-about-the-current-database `__ + """ # noqa: E501 + + def __init__(self, data: Json, strip_result: bool = False) -> None: + super().__init__(self._strip_result(data) if strip_result else data) + + @property + def name(self) -> str: + """The name of the current database.""" + return self._data["name"] # type: ignore[no-any-return] + + @property + def id(self) -> str: + """The id of the current database.""" + return self._data["id"] # type: ignore[no-any-return] + + @property + def path(self) -> Optional[str]: + """The filesystem path of the current database.""" + return self._data.get("path") + + @property + def is_system(self) -> bool: + """Whether the database is the `_system` database.""" + return self._data["isSystem"] # type: ignore[no-any-return] + + @property + def sharding(self) -> Optional[str]: + """The default sharding method for collections.""" + return self._data.get("sharding") + + @property + def replication_factor(self) -> Optional[int]: + """The default replication factor for collections.""" + return self._data.get("replicationFactor") + + @property + def write_concern(self) -> Optional[int]: + """The default write concern for collections.""" + return self._data.get("writeConcern") + + @staticmethod + def compatibility_formatter(data: Json) -> Json: + """python-arango compatibility formatter.""" + result: Json = {} + if "id" in data: + result["id"] = data["id"] + if "name" in data: + result["name"] = data["name"] + if "path" in data: + result["path"] = data["path"] + if "system" in data: + result["system"] = data["system"] + if "isSystem" in data: + result["system"] = data["isSystem"] + if "sharding" in data: + result["sharding"] = data["sharding"] + if "replicationFactor" in data: + result["replication_factor"] = data["replicationFactor"] + if "writeConcern" in data: + result["write_concern"] = data["writeConcern"] + if "replicationVersion" in data: + result["replication_version"] = data["replicationVersion"] + return result + + def format(self, formatter: Optional[Formatter] = None) -> Json: + """Apply a formatter to the data. + + By default, the python-arango compatibility formatter is applied. + """ + if formatter is not None: + return super().format(formatter) + return self.compatibility_formatter(self._data) + + +class CollectionProperties(JsonWrapper): + """Properties of a collection. + + Example: + .. code-block:: json + + { + "writeConcern" : 1, + "waitForSync" : true, + "usesRevisionsAsDocumentIds" : true, + "syncByRevision" : true, + "statusString" : "loaded", + "id" : "68452", + "isSmartChild" : false, + "schema" : null, + "name" : "products", + "type" : 2, + "status" : 3, + "cacheEnabled" : false, + "isSystem" : false, + "internalValidatorType" : 0, + "globallyUniqueId" : "hDA74058C1843/68452", + "keyOptions" : { + "allowUserKeys" : true, + "type" : "traditional", + "lastValue" : 0 + }, + "computedValues" : null, + "objectId" : "68453" + } + + References: + - `get-the-properties-of-a-collection `__ + """ # noqa: E501 + + def __init__(self, data: Json) -> None: + super().__init__(data) + + @property + def write_concern(self) -> Optional[int]: + return self._data.get("writeConcern") + + @property + def wait_for_sync(self) -> Optional[bool]: + return self._data.get("waitForSync") + + @property + def use_revisions_as_document_ids(self) -> Optional[bool]: + return self._data.get("usesRevisionsAsDocumentIds") + + @property + def sync_by_revision(self) -> Optional[bool]: + return self._data.get("syncByRevision") + + @property + def status_string(self) -> Optional[str]: + return self._data.get("statusString") + + @property + def id(self) -> str: + return self._data["id"] # type: ignore[no-any-return] + + @property + def is_smart_child(self) -> bool: + return self._data["isSmartChild"] # type: ignore[no-any-return] + + @property + def schema(self) -> Optional[Json]: + return self._data.get("schema") + + @property + def name(self) -> str: + return self._data["name"] # type: ignore[no-any-return] + + @property + def type(self) -> CollectionType: + return CollectionType.from_int(self._data["type"]) + + @property + def status(self) -> CollectionStatus: + return CollectionStatus.from_int(self._data["status"]) + + @property + def cache_enabled(self) -> Optional[bool]: + return self._data.get("cacheEnabled") + + @property + def is_system(self) -> bool: + return self._data["isSystem"] # type: ignore[no-any-return] + + @property + def internal_validator_type(self) -> Optional[int]: + return self._data.get("internalValidatorType") + + @property + def globally_unique_id(self) -> str: + return self._data["globallyUniqueId"] # type: ignore[no-any-return] + + @property + def key_options(self) -> KeyOptions: + return KeyOptions(self._data["keyOptions"]) + + @property + def computed_values(self) -> Optional[Json]: + return self._data.get("computedValues") + + @property + def object_id(self) -> str: + return self._data["objectId"] # type: ignore[no-any-return] + + @staticmethod + def compatibility_formatter(data: Json) -> Json: + """python-arango compatibility formatter.""" + result: Json = {} + if "id" in data: + result["id"] = data["id"] + if "objectId" in data: + result["object_id"] = data["objectId"] + if "name" in data: + result["name"] = data["name"] + if "isSystem" in data: + result["system"] = data["isSystem"] + if "isSmart" in data: + result["smart"] = data["isSmart"] + if "type" in data: + result["type"] = data["type"] + result["edge"] = data["type"] == 3 + if "waitForSync" in data: + result["sync"] = data["waitForSync"] + if "status" in data: + result["status"] = data["status"] + if "statusString" in data: + result["status_string"] = data["statusString"] + if "globallyUniqueId" in data: + result["global_id"] = data["globallyUniqueId"] + if "cacheEnabled" in data: + result["cache"] = data["cacheEnabled"] + if "replicationFactor" in data: + result["replication_factor"] = data["replicationFactor"] + if "minReplicationFactor" in data: + result["min_replication_factor"] = data["minReplicationFactor"] + if "writeConcern" in data: + result["write_concern"] = data["writeConcern"] + if "shards" in data: + result["shards"] = data["shards"] + if "replicationFactor" in data: + result["replication_factor"] = data["replicationFactor"] + if "numberOfShards" in data: + result["shard_count"] = data["numberOfShards"] + if "shardKeys" in data: + result["shard_fields"] = data["shardKeys"] + if "distributeShardsLike" in data: + result["shard_like"] = data["distributeShardsLike"] + if "shardingStrategy" in data: + result["sharding_strategy"] = data["shardingStrategy"] + if "smartJoinAttribute" in data: + result["smart_join_attribute"] = data["smartJoinAttribute"] + if "keyOptions" in data: + result["key_options"] = KeyOptions.compatibility_formatter( + data["keyOptions"] + ) + if "cid" in data: + result["cid"] = data["cid"] + if "version" in data: + result["version"] = data["version"] + if "allowUserKeys" in data: + result["user_keys"] = data["allowUserKeys"] + if "planId" in data: + result["plan_id"] = data["planId"] + if "deleted" in data: + result["deleted"] = data["deleted"] + if "syncByRevision" in data: + result["sync_by_revision"] = data["syncByRevision"] + if "tempObjectId" in data: + result["temp_object_id"] = data["tempObjectId"] + if "usesRevisionsAsDocumentIds" in data: + result["rev_as_id"] = data["usesRevisionsAsDocumentIds"] + if "isDisjoint" in data: + result["disjoint"] = data["isDisjoint"] + if "isSmartChild" in data: + result["smart_child"] = data["isSmartChild"] + if "minRevision" in data: + result["min_revision"] = data["minRevision"] + if "schema" in data: + result["schema"] = data["schema"] + if data.get("computedValues") is not None: + result["computedValues"] = data["computedValues"] + if "internalValidatorType" in data: + result["internal_validator_type"] = data["internalValidatorType"] + return result + + def format(self, formatter: Optional[Formatter] = None) -> Json: + """Apply a formatter to the data. + + By default, the python-arango compatibility formatter is applied. + """ + if formatter is not None: + return super().format(formatter) + return self.compatibility_formatter(self._data) diff --git a/tests/conftest.py b/tests/conftest.py index 8817abc..e997824 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -200,6 +200,12 @@ async def doc_col(db): await db.delete_collection(col_name) +@pytest.fixture +def bad_col(db): + col_name = generate_col_name() + return db.collection(col_name) + + @pytest_asyncio.fixture(scope="session", autouse=True) async def teardown(): yield diff --git a/tests/test_client.py b/tests/test_client.py index e6e0364..1616f47 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -122,7 +122,11 @@ async def test_client_jwt_auth(url, sys_db_name, basic_auth_root): async def test_client_jwt_superuser_auth(url, sys_db_name, basic_auth_root, token): # successful authentication async with ArangoClient(hosts=url) as client: - await client.db(sys_db_name, auth_method="superuser", token=token, verify=True) + db = await client.db( + sys_db_name, auth_method="superuser", token=token, verify=True + ) + await db.jwt_secrets() + await db.reload_jwt_secrets() # token missing async with ArangoClient(hosts=url) as client: diff --git a/tests/test_collection.py b/tests/test_collection.py new file mode 100644 index 0000000..8a3ac4b --- /dev/null +++ b/tests/test_collection.py @@ -0,0 +1,20 @@ +import pytest + +from arangoasync.exceptions import CollectionPropertiesError + + +def test_collection_attributes(db, doc_col): + assert doc_col.db_name == db.name + assert doc_col.name.startswith("test_collection") + assert repr(doc_col) == f"" + + +@pytest.mark.asyncio +async def test_collection_misc_methods(doc_col, bad_col): + # Properties + properties = await doc_col.properties() + assert properties.name == doc_col.name + assert properties.is_system is False + assert len(properties.format()) > 1 + with pytest.raises(CollectionPropertiesError): + await bad_col.properties() diff --git a/tests/test_database.py b/tests/test_database.py index 0f74f53..7c6e13e 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -1,3 +1,5 @@ +import asyncio + import pytest from arangoasync.collection import StandardCollection @@ -8,15 +10,41 @@ DatabaseCreateError, DatabaseDeleteError, DatabaseListError, + DatabasePropertiesError, + JWTSecretListError, + JWTSecretReloadError, + ServerStatusError, ) from arangoasync.typings import CollectionType, KeyOptions, UserInfo from tests.helpers import generate_col_name, generate_db_name, generate_username @pytest.mark.asyncio -async def test_database_misc_methods(sys_db): +async def test_database_misc_methods(sys_db, db, bad_db): + # Status status = await sys_db.status() assert status["server"] == "arango" + with pytest.raises(ServerStatusError): + await bad_db.status() + + sys_properties, db_properties = await asyncio.gather( + sys_db.properties(), db.properties() + ) + assert sys_properties.is_system is True + assert db_properties.is_system is False + assert sys_properties.name == sys_db.name + assert db_properties.name == db.name + assert db_properties.replication_factor == 3 + assert db_properties.write_concern == 2 + with pytest.raises(DatabasePropertiesError): + await bad_db.properties() + assert len(db_properties.format()) > 1 + + # JWT secrets + with pytest.raises(JWTSecretListError): + await bad_db.jwt_secrets() + with pytest.raises(JWTSecretReloadError): + await bad_db.reload_jwt_secrets() @pytest.mark.asyncio @@ -24,6 +52,7 @@ async def test_create_drop_database( arango_client, sys_db, db, + bad_db, basic_auth_root, password, cluster, @@ -60,12 +89,19 @@ async def test_create_drop_database( dbs = await sys_db.databases() assert db_name in dbs assert "_system" in dbs + dbs = await sys_db.databases_accessible_to_user() + assert db_name in dbs + assert "_system" in dbs + dbs = await db.databases_accessible_to_user() + assert db.name in dbs # Cannot list databases without permission with pytest.raises(DatabaseListError): await db.databases() with pytest.raises(DatabaseListError): await db.has_database(db_name) + with pytest.raises(DatabaseListError): + await bad_db.databases_accessible_to_user() # Databases can only be dropped from the system database with pytest.raises(DatabaseDeleteError): diff --git a/tests/test_user.py b/tests/test_user.py index 6ed66fe..4724927 100644 --- a/tests/test_user.py +++ b/tests/test_user.py @@ -1,8 +1,21 @@ import pytest -from arangoasync.exceptions import UserCreateError, UserDeleteError, UserListError +from arangoasync.auth import Auth +from arangoasync.errno import USER_NOT_FOUND +from arangoasync.exceptions import ( + CollectionCreateError, + DocumentInsertError, + PermissionResetError, + PermissionUpdateError, + UserCreateError, + UserDeleteError, + UserGetError, + UserListError, + UserReplaceError, + UserUpdateError, +) from arangoasync.typings import UserInfo -from tests.helpers import generate_string, generate_username +from tests.helpers import generate_col_name, generate_string, generate_username @pytest.mark.asyncio @@ -44,6 +57,10 @@ async def test_user_management(sys_db, db, bad_db): assert user.user == username assert user.active is True + # Get non-existing user + with pytest.raises(UserGetError): + await sys_db.user(generate_username()) + # Create already existing user with pytest.raises(UserCreateError): await sys_db.create_user( @@ -55,6 +72,48 @@ async def test_user_management(sys_db, db, bad_db): ) ) + # Update existing user + new_user = await sys_db.update_user( + UserInfo( + user=username, + password=password, + active=False, + extra={"bar": "baz"}, + ) + ) + assert new_user["user"] == username + assert new_user["active"] is False + assert new_user["extra"] == {"foo": "bar", "bar": "baz"} + assert await sys_db.user(username) == new_user + + # Update missing user + with pytest.raises(UserUpdateError) as err: + await sys_db.update_user( + UserInfo(user=generate_username(), password=generate_string()) + ) + assert err.value.error_code == USER_NOT_FOUND + + # Replace existing user + new_user = await sys_db.replace_user( + UserInfo( + user=username, + password=password, + active=True, + extra={"baz": "qux"}, + ) + ) + assert new_user["user"] == username + assert new_user["active"] is True + assert new_user["extra"] == {"baz": "qux"} + assert await sys_db.user(username) == new_user + + # Replace missing user + with pytest.raises(UserReplaceError) as err: + await sys_db.replace_user( + {"user": generate_username(), "password": generate_string()} + ) + assert err.value.error_code == USER_NOT_FOUND + # Delete the newly created user assert await sys_db.delete_user(username) is True users = await sys_db.users() @@ -73,3 +132,60 @@ async def test_user_management(sys_db, db, bad_db): await bad_db.users() with pytest.raises(UserListError): await bad_db.has_user(username) + + +@pytest.mark.asyncio +async def test_user_change_permissions(sys_db, arango_client, db): + username = generate_username() + password = generate_string() + auth = Auth(username, password) + + # Set read-only permissions + await sys_db.create_user(UserInfo(username, password)) + + # Should not be able to update permissions without permission + with pytest.raises(PermissionUpdateError): + await db.update_permission(username, "ro", db.name) + + await sys_db.update_permission(username, "ro", db.name) + + # Verify read-only permissions + permission = await sys_db.permission(username, db.name) + assert permission == "ro" + + # Should not be able to create a collection + col_name = generate_col_name() + db2 = await arango_client.db(db.name, auth=auth, verify=True) + with pytest.raises(CollectionCreateError): + await db2.create_collection(col_name) + + all_permissions = await sys_db.permissions(username) + assert "_system" in all_permissions + assert db.name in all_permissions + all_permissions = await sys_db.permissions(username, full=False) + assert all_permissions[db.name] == "ro" + + # Set read-write permissions + await sys_db.update_permission(username, "rw", db.name) + + # Should be able to create collection + col = await db2.create_collection(col_name) + await col.insert({"_key": "test"}) + + # Reset permissions + with pytest.raises(PermissionResetError): + await db.reset_permission(username, db.name) + await sys_db.reset_permission(username, db.name) + with pytest.raises(DocumentInsertError): + await col.insert({"_key": "test"}) + + # Allow rw access + await sys_db.update_permission(username, "rw", db.name) + await col.insert({"_key": "test2"}) + + # No access to collection + await sys_db.update_permission(username, "none", db.name, col_name) + with pytest.raises(DocumentInsertError): + await col.insert({"_key": "test"}) + + await db.delete_collection(col_name) From 0fe44244fe5ae0e3ee9299c6fb7bc6b08a7f5c78 Mon Sep 17 00:00:00 2001 From: Alex Petenchea Date: Tue, 29 Oct 2024 20:36:38 +0200 Subject: [PATCH 2/3] Fixing cluster-only properties --- tests/test_database.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/test_database.py b/tests/test_database.py index 7c6e13e..fc8b2bc 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -20,7 +20,7 @@ @pytest.mark.asyncio -async def test_database_misc_methods(sys_db, db, bad_db): +async def test_database_misc_methods(sys_db, db, bad_db, cluster): # Status status = await sys_db.status() assert status["server"] == "arango" @@ -34,8 +34,10 @@ async def test_database_misc_methods(sys_db, db, bad_db): assert db_properties.is_system is False assert sys_properties.name == sys_db.name assert db_properties.name == db.name - assert db_properties.replication_factor == 3 - assert db_properties.write_concern == 2 + if cluster: + assert db_properties.replication_factor == 3 + assert db_properties.write_concern == 2 + with pytest.raises(DatabasePropertiesError): await bad_db.properties() assert len(db_properties.format()) > 1 From 76e6e51f94941e8b364c54a035175bb62bac4922 Mon Sep 17 00:00:00 2001 From: Alex Petenchea Date: Tue, 29 Oct 2024 21:06:07 +0200 Subject: [PATCH 3/3] Fixing enterprise-only features --- arangoasync/typings.py | 48 +++++++++++++++++++++--------------------- tests/test_client.py | 9 +++++--- 2 files changed, 30 insertions(+), 27 deletions(-) diff --git a/arangoasync/typings.py b/arangoasync/typings.py index d0065bf..95f2d65 100644 --- a/arangoasync/typings.py +++ b/arangoasync/typings.py @@ -622,30 +622,30 @@ class CollectionProperties(JsonWrapper): Example: .. code-block:: json - { - "writeConcern" : 1, - "waitForSync" : true, - "usesRevisionsAsDocumentIds" : true, - "syncByRevision" : true, - "statusString" : "loaded", - "id" : "68452", - "isSmartChild" : false, - "schema" : null, - "name" : "products", - "type" : 2, - "status" : 3, - "cacheEnabled" : false, - "isSystem" : false, - "internalValidatorType" : 0, - "globallyUniqueId" : "hDA74058C1843/68452", - "keyOptions" : { - "allowUserKeys" : true, - "type" : "traditional", - "lastValue" : 0 - }, - "computedValues" : null, - "objectId" : "68453" - } + { + "writeConcern" : 1, + "waitForSync" : true, + "usesRevisionsAsDocumentIds" : true, + "syncByRevision" : true, + "statusString" : "loaded", + "id" : "68452", + "isSmartChild" : false, + "schema" : null, + "name" : "products", + "type" : 2, + "status" : 3, + "cacheEnabled" : false, + "isSystem" : false, + "internalValidatorType" : 0, + "globallyUniqueId" : "hDA74058C1843/68452", + "keyOptions" : { + "allowUserKeys" : true, + "type" : "traditional", + "lastValue" : 0 + }, + "computedValues" : null, + "objectId" : "68453" + } References: - `get-the-properties-of-a-collection `__ diff --git a/tests/test_client.py b/tests/test_client.py index 1616f47..718d307 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -119,14 +119,17 @@ async def test_client_jwt_auth(url, sys_db_name, basic_auth_root): @pytest.mark.asyncio -async def test_client_jwt_superuser_auth(url, sys_db_name, basic_auth_root, token): +async def test_client_jwt_superuser_auth( + url, sys_db_name, basic_auth_root, token, enterprise +): # successful authentication async with ArangoClient(hosts=url) as client: db = await client.db( sys_db_name, auth_method="superuser", token=token, verify=True ) - await db.jwt_secrets() - await db.reload_jwt_secrets() + if enterprise: + await db.jwt_secrets() + await db.reload_jwt_secrets() # token missing async with ArangoClient(hosts=url) as client: