diff --git a/CHANGELOG.md b/CHANGELOG.md index f833c9ca..92b011c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,7 +3,9 @@ See also https://github.com/neo4j/neo4j-python-driver/wiki for a full changelog. ## NEXT RELEASE -- No breaking or major changes. +- Deprecated setting attributes on `Neo4jError` like `message` and `code`. +- Deprecated undocumented method `Neo4jError.hydrate`. + It's internal and should not be used by client code. ## Version 5.25 diff --git a/docs/source/api.rst b/docs/source/api.rst index a623387c..35041a85 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -1989,6 +1989,17 @@ Errors ****** +GQL Errors +========== +.. autoexception:: neo4j.exceptions.GqlError() + :show-inheritance: + :members: gql_status, message, gql_status_description, gql_raw_classification, gql_classification, diagnostic_record, __cause__ + +.. autoclass:: neo4j.exceptions.GqlErrorClassification() + :show-inheritance: + :members: + + Neo4j Errors ============ diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index 87c78456..339c065f 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -217,6 +217,7 @@ def _to_auth_dict(cls, auth): try: return vars(auth) except (KeyError, TypeError) as e: + # TODO: 6.0 - change this to be a DriverError (or subclass) raise AuthError( f"Cannot determine auth details from {auth!r}" ) from e @@ -306,6 +307,7 @@ def protocol_handlers(cls, protocol_version=None): AsyncBolt5x4, AsyncBolt5x5, AsyncBolt5x6, + AsyncBolt5x7, ) handlers = { @@ -322,6 +324,7 @@ def protocol_handlers(cls, protocol_version=None): AsyncBolt5x4.PROTOCOL_VERSION: AsyncBolt5x4, AsyncBolt5x5.PROTOCOL_VERSION: AsyncBolt5x5, AsyncBolt5x6.PROTOCOL_VERSION: AsyncBolt5x6, + AsyncBolt5x7.PROTOCOL_VERSION: AsyncBolt5x7, } if protocol_version is None: @@ -458,7 +461,10 @@ async def open( # avoid new lines after imports for better readability and conciseness # fmt: off - if protocol_version == (5, 6): + if protocol_version == (5, 7): + from ._bolt5 import AsyncBolt5x7 + bolt_cls = AsyncBolt5x7 + elif protocol_version == (5, 6): from ._bolt5 import AsyncBolt5x6 bolt_cls = AsyncBolt5x6 elif protocol_version == (5, 5): @@ -506,6 +512,7 @@ async def open( await AsyncBoltSocket.close_socket(s) supported_versions = cls.protocol_handlers().keys() + # TODO: 6.0 - raise public DriverError subclass instead raise BoltHandshakeError( "The neo4j server does not support communication with this " "driver. This driver has support for Bolt protocols " @@ -909,6 +916,16 @@ def goodbye(self, dehydration_hooks=None, hydration_hooks=None): def new_hydration_scope(self): return self.hydration_handler.new_hydration_scope() + def _default_hydration_hooks(self, dehydration_hooks, hydration_hooks): + if dehydration_hooks is not None and hydration_hooks is not None: + return dehydration_hooks, hydration_hooks + hydration_scope = self.new_hydration_scope() + if dehydration_hooks is None: + dehydration_hooks = hydration_scope.dehydration_hooks + if hydration_hooks is None: + hydration_hooks = hydration_scope.hydration_hooks + return dehydration_hooks, hydration_hooks + def _append( self, signature, fields=(), response=None, dehydration_hooks=None ): diff --git a/src/neo4j/_async/io/_bolt3.py b/src/neo4j/_async/io/_bolt3.py index 1617f0d7..08e75abb 100644 --- a/src/neo4j/_async/io/_bolt3.py +++ b/src/neo4j/_async/io/_bolt3.py @@ -215,6 +215,9 @@ async def hello(self, dehydration_hooks=None, hydration_hooks=None): or self.notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) headers = self.get_base_headers() headers.update(self.auth_dict) logged_headers = dict(headers) @@ -275,6 +278,9 @@ async def route( f"{self.PROTOCOL_VERSION!r}. Trying to impersonate " f"{imp_user!r}." ) + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) metadata = {} records = [] @@ -337,6 +343,9 @@ def run( or notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) if not parameters: parameters = {} extra = {} @@ -379,6 +388,9 @@ def discard( **handlers, ): # Just ignore n and qid, it is not supported in the Bolt 3 Protocol. + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: DISCARD_ALL", self.local_port) self._append( b"\x2f", @@ -396,6 +408,9 @@ def pull( **handlers, ): # Just ignore n and qid, it is not supported in the Bolt 3 Protocol. + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: PULL_ALL", self.local_port) self._append( b"\x3f", @@ -435,6 +450,9 @@ def begin( or notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {} if mode in {READ_ACCESS, "r"}: # It will default to mode "w" if nothing is specified @@ -464,6 +482,9 @@ def begin( ) def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: COMMIT", self.local_port) self._append( b"\x12", @@ -475,6 +496,9 @@ def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): def rollback( self, dehydration_hooks=None, hydration_hooks=None, **handlers ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: ROLLBACK", self.local_port) self._append( b"\x13", @@ -490,6 +514,9 @@ async def reset(self, dehydration_hooks=None, hydration_hooks=None): Add a RESET message to the outgoing queue, send it and consume all remaining messages. """ + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: RESET", self.local_port) response = ResetResponse(self, "reset", hydration_hooks) self._append( @@ -499,6 +526,9 @@ async def reset(self, dehydration_hooks=None, hydration_hooks=None): await self.fetch_all() def goodbye(self, dehydration_hooks=None, hydration_hooks=None): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: GOODBYE", self.local_port) self._append(b"\x02", (), dehydration_hooks=dehydration_hooks) diff --git a/src/neo4j/_async/io/_bolt4.py b/src/neo4j/_async/io/_bolt4.py index 9be53e75..202d5570 100644 --- a/src/neo4j/_async/io/_bolt4.py +++ b/src/neo4j/_async/io/_bolt4.py @@ -131,6 +131,9 @@ async def hello(self, dehydration_hooks=None, hydration_hooks=None): or self.notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) headers = self.get_base_headers() headers.update(self.auth_dict) logged_headers = dict(headers) @@ -184,6 +187,9 @@ async def route( f"{self.PROTOCOL_VERSION!r}. Trying to impersonate " f"{imp_user!r}." ) + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) metadata = {} records = [] @@ -244,6 +250,9 @@ def run( or notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) if not parameters: parameters = {} extra = {} @@ -292,6 +301,9 @@ def discard( hydration_hooks=None, **handlers, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {"n": n} if qid != -1: extra["qid"] = qid @@ -311,6 +323,9 @@ def pull( hydration_hooks=None, **handlers, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {"n": n} if qid != -1: extra["qid"] = qid @@ -347,6 +362,9 @@ def begin( or notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {} if mode in {READ_ACCESS, "r"}: # It will default to mode "w" if nothing is specified @@ -379,6 +397,9 @@ def begin( ) def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: COMMIT", self.local_port) self._append( b"\x12", @@ -390,6 +411,9 @@ def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): def rollback( self, dehydration_hooks=None, hydration_hooks=None, **handlers ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: ROLLBACK", self.local_port) self._append( b"\x13", @@ -405,6 +429,9 @@ async def reset(self, dehydration_hooks=None, hydration_hooks=None): Add a RESET message to the outgoing queue, send it and consume all remaining messages. """ + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: RESET", self.local_port) response = ResetResponse(self, "reset", hydration_hooks) self._append( @@ -414,6 +441,9 @@ async def reset(self, dehydration_hooks=None, hydration_hooks=None): await self.fetch_all() def goodbye(self, dehydration_hooks=None, hydration_hooks=None): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: GOODBYE", self.local_port) self._append(b"\x02", (), dehydration_hooks=dehydration_hooks) @@ -547,6 +577,9 @@ async def route( f"{self.PROTOCOL_VERSION!r}. Trying to impersonate " f"{imp_user!r}." ) + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) routing_context = self.routing_context or {} log.debug( @@ -576,6 +609,9 @@ async def hello(self, dehydration_hooks=None, hydration_hooks=None): or self.notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) def on_success(metadata): self.configuration_hints.update(metadata.pop("hints", {})) @@ -635,6 +671,9 @@ async def route( dehydration_hooks=None, hydration_hooks=None, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) routing_context = self.routing_context or {} db_context = {} if database is not None: @@ -683,6 +722,9 @@ def run( or notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) if not parameters: parameters = {} extra = {} @@ -744,6 +786,9 @@ def begin( or notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {} if mode in {READ_ACCESS, "r"}: # It will default to mode "w" if nothing is specified diff --git a/src/neo4j/_async/io/_bolt5.py b/src/neo4j/_async/io/_bolt5.py index 37031971..06336193 100644 --- a/src/neo4j/_async/io/_bolt5.py +++ b/src/neo4j/_async/io/_bolt5.py @@ -136,6 +136,9 @@ async def hello(self, dehydration_hooks=None, hydration_hooks=None): or self.notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) def on_success(metadata): self.configuration_hints.update(metadata.pop("hints", {})) @@ -200,6 +203,9 @@ async def route( dehydration_hooks=None, hydration_hooks=None, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) routing_context = self.routing_context or {} db_context = {} if database is not None: @@ -221,7 +227,7 @@ async def route( response=Response( self, "route", hydration_hooks, on_success=metadata.update ), - dehydration_hooks=hydration_hooks, + dehydration_hooks=dehydration_hooks, ) await self.send_all() await self.fetch_all() @@ -248,6 +254,9 @@ def run( or notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) if not parameters: parameters = {} extra = {} @@ -298,6 +307,9 @@ def discard( hydration_hooks=None, **handlers, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {"n": n} if qid != -1: extra["qid"] = qid @@ -317,6 +329,9 @@ def pull( hydration_hooks=None, **handlers, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {"n": n} if qid != -1: extra["qid"] = qid @@ -347,6 +362,9 @@ def begin( or notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {} if mode in {READ_ACCESS, "r"}: # It will default to mode "w" if nothing is specified @@ -381,6 +399,9 @@ def begin( ) def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: COMMIT", self.local_port) self._append( b"\x12", @@ -392,6 +413,9 @@ def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): def rollback( self, dehydration_hooks=None, hydration_hooks=None, **handlers ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: ROLLBACK", self.local_port) self._append( b"\x13", @@ -407,6 +431,9 @@ async def reset(self, dehydration_hooks=None, hydration_hooks=None): Add a RESET message to the outgoing queue, send it and consume all remaining messages. """ + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: RESET", self.local_port) response = ResetResponse(self, "reset", hydration_hooks) self._append( @@ -416,6 +443,9 @@ async def reset(self, dehydration_hooks=None, hydration_hooks=None): await self.fetch_all() def goodbye(self, dehydration_hooks=None, hydration_hooks=None): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: GOODBYE", self.local_port) self._append(b"\x02", (), dehydration_hooks=dehydration_hooks) @@ -580,6 +610,9 @@ async def hello(self, dehydration_hooks=None, hydration_hooks=None): or self.notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) def on_success(metadata): self.configuration_hints.update(metadata.pop("hints", {})) @@ -620,6 +653,9 @@ def on_success(metadata): check_supported_server_product(self.server_info.agent) def logon(self, dehydration_hooks=None, hydration_hooks=None): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) logged_auth_dict = dict(self.auth_dict) if "credentials" in logged_auth_dict: logged_auth_dict["credentials"] = "*******" @@ -632,6 +668,9 @@ def logon(self, dehydration_hooks=None, hydration_hooks=None): ) def logoff(self, dehydration_hooks=None, hydration_hooks=None): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: LOGOFF", self.local_port) self._append( b"\x6b", @@ -658,6 +697,10 @@ def get_base_headers(self): return headers async def hello(self, dehydration_hooks=None, hydration_hooks=None): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) + def on_success(metadata): self.configuration_hints.update(metadata.pop("hints", {})) self.server_info.update(metadata) @@ -709,6 +752,9 @@ def run( hydration_hooks=None, **handlers, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) if not parameters: parameters = {} extra = {} @@ -773,6 +819,9 @@ def begin( hydration_hooks=None, **handlers, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {} if mode in {READ_ACCESS, "r"}: # It will default to mode "w" if nothing is specified @@ -838,6 +887,9 @@ def telemetry( "telemetry.enabled", False ): return + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) api_raw = int(api) log.debug( "[#%04X] C: TELEMETRY %i # (%r)", self.local_port, api_raw, api @@ -877,6 +929,9 @@ def run( hydration_hooks=None, **handlers, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) if not parameters: parameters = {} extra = {} @@ -941,6 +996,9 @@ def begin( hydration_hooks=None, **handlers, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {} if mode in {READ_ACCESS, "r"}: # It will default to mode "w" if nothing is specified @@ -982,7 +1040,7 @@ def begin( dehydration_hooks=dehydration_hooks, ) - DEFAULT_DIAGNOSTIC_RECORD = ( + DEFAULT_STATUS_DIAGNOSTIC_RECORD = ( ("OPERATION", ""), ("OPERATION_CODE", "0"), ("CURRENT_SCHEMA", "/"), @@ -1009,7 +1067,7 @@ def enrich(metadata_): diag_record, ) continue - for key, value in self.DEFAULT_DIAGNOSTIC_RECORD: + for key, value in self.DEFAULT_STATUS_DIAGNOSTIC_RECORD: diag_record.setdefault(key, value) enrich(metadata) @@ -1062,15 +1120,108 @@ def enrich(metadata_): if not isinstance(diag_record, dict): log.info( "[#%04X] _: Server supplied an " - "invalid diagnostic record (%r).", + "invalid status diagnostic record (%r).", self.local_port, diag_record, ) continue - for key, value in self.DEFAULT_DIAGNOSTIC_RECORD: + for key, value in self.DEFAULT_STATUS_DIAGNOSTIC_RECORD: diag_record.setdefault(key, value) enrich(metadata) await AsyncUtil.callback(wrapped_handler, metadata) return handler + + +class AsyncBolt5x7(AsyncBolt5x6): + PROTOCOL_VERSION = Version(5, 7) + + DEFAULT_ERROR_DIAGNOSTIC_RECORD = ( + AsyncBolt5x5.DEFAULT_STATUS_DIAGNOSTIC_RECORD + ) + + def _enrich_error_diagnostic_record(self, metadata): + if not isinstance(metadata, dict): + return + diag_record = metadata.setdefault("diagnostic_record", {}) + if not isinstance(diag_record, dict): + log.info( + "[#%04X] _: Server supplied an " + "invalid error diagnostic record (%r).", + self.local_port, + diag_record, + ) + else: + for key, value in self.DEFAULT_ERROR_DIAGNOSTIC_RECORD: + diag_record.setdefault(key, value) + self._enrich_error_diagnostic_record(metadata.get("cause")) + + async def _process_message(self, tag, fields): + """Process at most one message from the server, if available. + + :returns: 2-tuple of number of detail messages and number of summary + messages fetched + """ + details = [] + summary_signature = summary_metadata = None + if tag == b"\x71": # RECORD + details = fields + elif fields: + summary_signature = tag + summary_metadata = fields[0] + else: + summary_signature = tag + + if details: + # Do not log any data + log.debug("[#%04X] S: RECORD * %d", self.local_port, len(details)) + await self.responses[0].on_records(details) + + if summary_signature is None: + return len(details), 0 + + response = self.responses.popleft() + response.complete = True + if summary_signature == b"\x70": + log.debug( + "[#%04X] S: SUCCESS %r", self.local_port, summary_metadata + ) + self._server_state_manager.transition( + response.message, summary_metadata + ) + await response.on_success(summary_metadata or {}) + elif summary_signature == b"\x7e": + log.debug("[#%04X] S: IGNORED", self.local_port) + await response.on_ignored(summary_metadata or {}) + elif summary_signature == b"\x7f": + log.debug( + "[#%04X] S: FAILURE %r", self.local_port, summary_metadata + ) + self._server_state_manager.state = self.bolt_states.FAILED + self._enrich_error_diagnostic_record(summary_metadata) + try: + await response.on_failure(summary_metadata or {}) + except (ServiceUnavailable, DatabaseUnavailable): + if self.pool: + await self.pool.deactivate(address=self.unresolved_address) + raise + except (NotALeader, ForbiddenOnReadOnlyDatabase): + if self.pool: + await self.pool.on_write_failure( + address=self.unresolved_address, + database=self.last_database, + ) + raise + except Neo4jError as e: + if self.pool: + await self.pool.on_neo4j_error(e, self) + raise + else: + sig_int = ord(summary_signature) + raise BoltProtocolError( + f"Unexpected response message with signature {sig_int:02X}", + self.unresolved_address, + ) + + return len(details), 1 diff --git a/src/neo4j/_async/io/_common.py b/src/neo4j/_async/io/_common.py index 2aa55f61..2b733507 100644 --- a/src/neo4j/_async/io/_common.py +++ b/src/neo4j/_async/io/_common.py @@ -21,6 +21,7 @@ from ..._async_compat.util import AsyncUtil from ..._exceptions import SocketDeadlineExceededError +from ...api import Version from ...exceptions import ( Neo4jError, ServiceUnavailable, @@ -29,6 +30,8 @@ ) +GQL_ERROR_AWARE_PROTOCOL = Version(5, 7) + log = logging.getLogger("neo4j.io") @@ -248,7 +251,7 @@ async def on_failure(self, metadata): await AsyncUtil.callback(handler, metadata) handler = self.handlers.get("on_summary") await AsyncUtil.callback(handler) - raise Neo4jError.hydrate(**metadata) + raise self._hydrate_error(metadata) async def on_ignored(self, metadata=None): """Handle an IGNORED message been received.""" @@ -257,6 +260,12 @@ async def on_ignored(self, metadata=None): handler = self.handlers.get("on_summary") await AsyncUtil.callback(handler) + def _hydrate_error(self, metadata): + if self.connection.PROTOCOL_VERSION >= GQL_ERROR_AWARE_PROTOCOL: + return Neo4jError._hydrate_gql(**metadata) + else: + return Neo4jError._hydrate_neo4j(**metadata) + class InitResponse(Response): async def on_failure(self, metadata): @@ -271,7 +280,7 @@ async def on_failure(self, metadata): "message", "Connection initialisation failed due to an unknown error", ) - raise Neo4jError.hydrate(**metadata) + raise self._hydrate_error(metadata) class LogonResponse(InitResponse): @@ -283,7 +292,7 @@ async def on_failure(self, metadata): await AsyncUtil.callback(handler, metadata) handler = self.handlers.get("on_summary") await AsyncUtil.callback(handler) - raise Neo4jError.hydrate(**metadata) + raise self._hydrate_error(metadata) class ResetResponse(Response): diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index 0653ff71..7a520abe 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -341,6 +341,7 @@ async def health_check(connection_, deadline_): or not await self.cond.wait(timeout) ): log.debug("[#0000] _: acquisition timed out") + # TODO: 6.0 - change this to be a DriverError (or subclass) raise ClientError( "failed to obtain a connection from the pool within " f"{deadline.original_timeout!r}s (timeout)" @@ -1055,8 +1056,10 @@ async def acquire( liveness_check_timeout, ): if access_mode not in {WRITE_ACCESS, READ_ACCESS}: + # TODO: 6.0 - change this to be a ValueError raise ClientError(f"Non valid 'access_mode'; {access_mode}") if not timeout: + # TODO: 6.0 - change this to be a ValueError raise ClientError( f"'timeout' must be a float larger than 0; {timeout}" ) diff --git a/src/neo4j/_async/work/session.py b/src/neo4j/_async/work/session.py index 8d98124e..72f79464 100644 --- a/src/neo4j/_async/work/session.py +++ b/src/neo4j/_async/work/session.py @@ -294,6 +294,7 @@ async def run( raise TypeError("query must be a string or a Query instance") if self._transaction: + # TODO: 6.0 - change this to be a TransactionError raise ClientError( "Explicit Transaction must be handled explicitly" ) diff --git a/src/neo4j/_codec/hydration/v1/hydration_handler.py b/src/neo4j/_codec/hydration/v1/hydration_handler.py index 8b3c9977..f5ce1bbd 100644 --- a/src/neo4j/_codec/hydration/v1/hydration_handler.py +++ b/src/neo4j/_codec/hydration/v1/hydration_handler.py @@ -154,7 +154,6 @@ def hydrate_path(self, nodes, relationships, sequence): class HydrationHandler(HydrationHandlerABC): def __init__(self): super().__init__() - self._created_scope = False self.struct_hydration_functions = { **self.struct_hydration_functions, b"X": spatial.hydrate_point, @@ -201,8 +200,6 @@ def __init__(self): def patch_utc(self): from ..v2 import temporal as temporal_v2 - assert not self._created_scope - del self.struct_hydration_functions[b"F"] del self.struct_hydration_functions[b"f"] self.struct_hydration_functions.update( @@ -226,5 +223,4 @@ def patch_utc(self): ) def new_hydration_scope(self): - self._created_scope = True return HydrationScope(self, _GraphHydrator()) diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index 3aa2f020..f1176ba0 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -217,6 +217,7 @@ def _to_auth_dict(cls, auth): try: return vars(auth) except (KeyError, TypeError) as e: + # TODO: 6.0 - change this to be a DriverError (or subclass) raise AuthError( f"Cannot determine auth details from {auth!r}" ) from e @@ -306,6 +307,7 @@ def protocol_handlers(cls, protocol_version=None): Bolt5x4, Bolt5x5, Bolt5x6, + Bolt5x7, ) handlers = { @@ -322,6 +324,7 @@ def protocol_handlers(cls, protocol_version=None): Bolt5x4.PROTOCOL_VERSION: Bolt5x4, Bolt5x5.PROTOCOL_VERSION: Bolt5x5, Bolt5x6.PROTOCOL_VERSION: Bolt5x6, + Bolt5x7.PROTOCOL_VERSION: Bolt5x7, } if protocol_version is None: @@ -458,7 +461,10 @@ def open( # avoid new lines after imports for better readability and conciseness # fmt: off - if protocol_version == (5, 6): + if protocol_version == (5, 7): + from ._bolt5 import Bolt5x7 + bolt_cls = Bolt5x7 + elif protocol_version == (5, 6): from ._bolt5 import Bolt5x6 bolt_cls = Bolt5x6 elif protocol_version == (5, 5): @@ -506,6 +512,7 @@ def open( BoltSocket.close_socket(s) supported_versions = cls.protocol_handlers().keys() + # TODO: 6.0 - raise public DriverError subclass instead raise BoltHandshakeError( "The neo4j server does not support communication with this " "driver. This driver has support for Bolt protocols " @@ -909,6 +916,16 @@ def goodbye(self, dehydration_hooks=None, hydration_hooks=None): def new_hydration_scope(self): return self.hydration_handler.new_hydration_scope() + def _default_hydration_hooks(self, dehydration_hooks, hydration_hooks): + if dehydration_hooks is not None and hydration_hooks is not None: + return dehydration_hooks, hydration_hooks + hydration_scope = self.new_hydration_scope() + if dehydration_hooks is None: + dehydration_hooks = hydration_scope.dehydration_hooks + if hydration_hooks is None: + hydration_hooks = hydration_scope.hydration_hooks + return dehydration_hooks, hydration_hooks + def _append( self, signature, fields=(), response=None, dehydration_hooks=None ): diff --git a/src/neo4j/_sync/io/_bolt3.py b/src/neo4j/_sync/io/_bolt3.py index 2847781e..e3cfd142 100644 --- a/src/neo4j/_sync/io/_bolt3.py +++ b/src/neo4j/_sync/io/_bolt3.py @@ -215,6 +215,9 @@ def hello(self, dehydration_hooks=None, hydration_hooks=None): or self.notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) headers = self.get_base_headers() headers.update(self.auth_dict) logged_headers = dict(headers) @@ -275,6 +278,9 @@ def route( f"{self.PROTOCOL_VERSION!r}. Trying to impersonate " f"{imp_user!r}." ) + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) metadata = {} records = [] @@ -337,6 +343,9 @@ def run( or notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) if not parameters: parameters = {} extra = {} @@ -379,6 +388,9 @@ def discard( **handlers, ): # Just ignore n and qid, it is not supported in the Bolt 3 Protocol. + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: DISCARD_ALL", self.local_port) self._append( b"\x2f", @@ -396,6 +408,9 @@ def pull( **handlers, ): # Just ignore n and qid, it is not supported in the Bolt 3 Protocol. + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: PULL_ALL", self.local_port) self._append( b"\x3f", @@ -435,6 +450,9 @@ def begin( or notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {} if mode in {READ_ACCESS, "r"}: # It will default to mode "w" if nothing is specified @@ -464,6 +482,9 @@ def begin( ) def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: COMMIT", self.local_port) self._append( b"\x12", @@ -475,6 +496,9 @@ def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): def rollback( self, dehydration_hooks=None, hydration_hooks=None, **handlers ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: ROLLBACK", self.local_port) self._append( b"\x13", @@ -490,6 +514,9 @@ def reset(self, dehydration_hooks=None, hydration_hooks=None): Add a RESET message to the outgoing queue, send it and consume all remaining messages. """ + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: RESET", self.local_port) response = ResetResponse(self, "reset", hydration_hooks) self._append( @@ -499,6 +526,9 @@ def reset(self, dehydration_hooks=None, hydration_hooks=None): self.fetch_all() def goodbye(self, dehydration_hooks=None, hydration_hooks=None): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: GOODBYE", self.local_port) self._append(b"\x02", (), dehydration_hooks=dehydration_hooks) diff --git a/src/neo4j/_sync/io/_bolt4.py b/src/neo4j/_sync/io/_bolt4.py index 023d16d8..69bb6dd6 100644 --- a/src/neo4j/_sync/io/_bolt4.py +++ b/src/neo4j/_sync/io/_bolt4.py @@ -131,6 +131,9 @@ def hello(self, dehydration_hooks=None, hydration_hooks=None): or self.notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) headers = self.get_base_headers() headers.update(self.auth_dict) logged_headers = dict(headers) @@ -184,6 +187,9 @@ def route( f"{self.PROTOCOL_VERSION!r}. Trying to impersonate " f"{imp_user!r}." ) + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) metadata = {} records = [] @@ -244,6 +250,9 @@ def run( or notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) if not parameters: parameters = {} extra = {} @@ -292,6 +301,9 @@ def discard( hydration_hooks=None, **handlers, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {"n": n} if qid != -1: extra["qid"] = qid @@ -311,6 +323,9 @@ def pull( hydration_hooks=None, **handlers, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {"n": n} if qid != -1: extra["qid"] = qid @@ -347,6 +362,9 @@ def begin( or notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {} if mode in {READ_ACCESS, "r"}: # It will default to mode "w" if nothing is specified @@ -379,6 +397,9 @@ def begin( ) def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: COMMIT", self.local_port) self._append( b"\x12", @@ -390,6 +411,9 @@ def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): def rollback( self, dehydration_hooks=None, hydration_hooks=None, **handlers ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: ROLLBACK", self.local_port) self._append( b"\x13", @@ -405,6 +429,9 @@ def reset(self, dehydration_hooks=None, hydration_hooks=None): Add a RESET message to the outgoing queue, send it and consume all remaining messages. """ + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: RESET", self.local_port) response = ResetResponse(self, "reset", hydration_hooks) self._append( @@ -414,6 +441,9 @@ def reset(self, dehydration_hooks=None, hydration_hooks=None): self.fetch_all() def goodbye(self, dehydration_hooks=None, hydration_hooks=None): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: GOODBYE", self.local_port) self._append(b"\x02", (), dehydration_hooks=dehydration_hooks) @@ -547,6 +577,9 @@ def route( f"{self.PROTOCOL_VERSION!r}. Trying to impersonate " f"{imp_user!r}." ) + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) routing_context = self.routing_context or {} log.debug( @@ -576,6 +609,9 @@ def hello(self, dehydration_hooks=None, hydration_hooks=None): or self.notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) def on_success(metadata): self.configuration_hints.update(metadata.pop("hints", {})) @@ -635,6 +671,9 @@ def route( dehydration_hooks=None, hydration_hooks=None, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) routing_context = self.routing_context or {} db_context = {} if database is not None: @@ -683,6 +722,9 @@ def run( or notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) if not parameters: parameters = {} extra = {} @@ -744,6 +786,9 @@ def begin( or notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {} if mode in {READ_ACCESS, "r"}: # It will default to mode "w" if nothing is specified diff --git a/src/neo4j/_sync/io/_bolt5.py b/src/neo4j/_sync/io/_bolt5.py index 7692d177..4138a9d5 100644 --- a/src/neo4j/_sync/io/_bolt5.py +++ b/src/neo4j/_sync/io/_bolt5.py @@ -136,6 +136,9 @@ def hello(self, dehydration_hooks=None, hydration_hooks=None): or self.notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) def on_success(metadata): self.configuration_hints.update(metadata.pop("hints", {})) @@ -200,6 +203,9 @@ def route( dehydration_hooks=None, hydration_hooks=None, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) routing_context = self.routing_context or {} db_context = {} if database is not None: @@ -221,7 +227,7 @@ def route( response=Response( self, "route", hydration_hooks, on_success=metadata.update ), - dehydration_hooks=hydration_hooks, + dehydration_hooks=dehydration_hooks, ) self.send_all() self.fetch_all() @@ -248,6 +254,9 @@ def run( or notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) if not parameters: parameters = {} extra = {} @@ -298,6 +307,9 @@ def discard( hydration_hooks=None, **handlers, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {"n": n} if qid != -1: extra["qid"] = qid @@ -317,6 +329,9 @@ def pull( hydration_hooks=None, **handlers, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {"n": n} if qid != -1: extra["qid"] = qid @@ -347,6 +362,9 @@ def begin( or notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {} if mode in {READ_ACCESS, "r"}: # It will default to mode "w" if nothing is specified @@ -381,6 +399,9 @@ def begin( ) def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: COMMIT", self.local_port) self._append( b"\x12", @@ -392,6 +413,9 @@ def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): def rollback( self, dehydration_hooks=None, hydration_hooks=None, **handlers ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: ROLLBACK", self.local_port) self._append( b"\x13", @@ -407,6 +431,9 @@ def reset(self, dehydration_hooks=None, hydration_hooks=None): Add a RESET message to the outgoing queue, send it and consume all remaining messages. """ + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: RESET", self.local_port) response = ResetResponse(self, "reset", hydration_hooks) self._append( @@ -416,6 +443,9 @@ def reset(self, dehydration_hooks=None, hydration_hooks=None): self.fetch_all() def goodbye(self, dehydration_hooks=None, hydration_hooks=None): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: GOODBYE", self.local_port) self._append(b"\x02", (), dehydration_hooks=dehydration_hooks) @@ -580,6 +610,9 @@ def hello(self, dehydration_hooks=None, hydration_hooks=None): or self.notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) def on_success(metadata): self.configuration_hints.update(metadata.pop("hints", {})) @@ -620,6 +653,9 @@ def on_success(metadata): check_supported_server_product(self.server_info.agent) def logon(self, dehydration_hooks=None, hydration_hooks=None): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) logged_auth_dict = dict(self.auth_dict) if "credentials" in logged_auth_dict: logged_auth_dict["credentials"] = "*******" @@ -632,6 +668,9 @@ def logon(self, dehydration_hooks=None, hydration_hooks=None): ) def logoff(self, dehydration_hooks=None, hydration_hooks=None): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: LOGOFF", self.local_port) self._append( b"\x6b", @@ -658,6 +697,10 @@ def get_base_headers(self): return headers def hello(self, dehydration_hooks=None, hydration_hooks=None): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) + def on_success(metadata): self.configuration_hints.update(metadata.pop("hints", {})) self.server_info.update(metadata) @@ -709,6 +752,9 @@ def run( hydration_hooks=None, **handlers, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) if not parameters: parameters = {} extra = {} @@ -773,6 +819,9 @@ def begin( hydration_hooks=None, **handlers, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {} if mode in {READ_ACCESS, "r"}: # It will default to mode "w" if nothing is specified @@ -838,6 +887,9 @@ def telemetry( "telemetry.enabled", False ): return + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) api_raw = int(api) log.debug( "[#%04X] C: TELEMETRY %i # (%r)", self.local_port, api_raw, api @@ -877,6 +929,9 @@ def run( hydration_hooks=None, **handlers, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) if not parameters: parameters = {} extra = {} @@ -941,6 +996,9 @@ def begin( hydration_hooks=None, **handlers, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {} if mode in {READ_ACCESS, "r"}: # It will default to mode "w" if nothing is specified @@ -982,7 +1040,7 @@ def begin( dehydration_hooks=dehydration_hooks, ) - DEFAULT_DIAGNOSTIC_RECORD = ( + DEFAULT_STATUS_DIAGNOSTIC_RECORD = ( ("OPERATION", ""), ("OPERATION_CODE", "0"), ("CURRENT_SCHEMA", "/"), @@ -1009,7 +1067,7 @@ def enrich(metadata_): diag_record, ) continue - for key, value in self.DEFAULT_DIAGNOSTIC_RECORD: + for key, value in self.DEFAULT_STATUS_DIAGNOSTIC_RECORD: diag_record.setdefault(key, value) enrich(metadata) @@ -1062,15 +1120,108 @@ def enrich(metadata_): if not isinstance(diag_record, dict): log.info( "[#%04X] _: Server supplied an " - "invalid diagnostic record (%r).", + "invalid status diagnostic record (%r).", self.local_port, diag_record, ) continue - for key, value in self.DEFAULT_DIAGNOSTIC_RECORD: + for key, value in self.DEFAULT_STATUS_DIAGNOSTIC_RECORD: diag_record.setdefault(key, value) enrich(metadata) Util.callback(wrapped_handler, metadata) return handler + + +class Bolt5x7(Bolt5x6): + PROTOCOL_VERSION = Version(5, 7) + + DEFAULT_ERROR_DIAGNOSTIC_RECORD = ( + Bolt5x5.DEFAULT_STATUS_DIAGNOSTIC_RECORD + ) + + def _enrich_error_diagnostic_record(self, metadata): + if not isinstance(metadata, dict): + return + diag_record = metadata.setdefault("diagnostic_record", {}) + if not isinstance(diag_record, dict): + log.info( + "[#%04X] _: Server supplied an " + "invalid error diagnostic record (%r).", + self.local_port, + diag_record, + ) + else: + for key, value in self.DEFAULT_ERROR_DIAGNOSTIC_RECORD: + diag_record.setdefault(key, value) + self._enrich_error_diagnostic_record(metadata.get("cause")) + + def _process_message(self, tag, fields): + """Process at most one message from the server, if available. + + :returns: 2-tuple of number of detail messages and number of summary + messages fetched + """ + details = [] + summary_signature = summary_metadata = None + if tag == b"\x71": # RECORD + details = fields + elif fields: + summary_signature = tag + summary_metadata = fields[0] + else: + summary_signature = tag + + if details: + # Do not log any data + log.debug("[#%04X] S: RECORD * %d", self.local_port, len(details)) + self.responses[0].on_records(details) + + if summary_signature is None: + return len(details), 0 + + response = self.responses.popleft() + response.complete = True + if summary_signature == b"\x70": + log.debug( + "[#%04X] S: SUCCESS %r", self.local_port, summary_metadata + ) + self._server_state_manager.transition( + response.message, summary_metadata + ) + response.on_success(summary_metadata or {}) + elif summary_signature == b"\x7e": + log.debug("[#%04X] S: IGNORED", self.local_port) + response.on_ignored(summary_metadata or {}) + elif summary_signature == b"\x7f": + log.debug( + "[#%04X] S: FAILURE %r", self.local_port, summary_metadata + ) + self._server_state_manager.state = self.bolt_states.FAILED + self._enrich_error_diagnostic_record(summary_metadata) + try: + response.on_failure(summary_metadata or {}) + except (ServiceUnavailable, DatabaseUnavailable): + if self.pool: + self.pool.deactivate(address=self.unresolved_address) + raise + except (NotALeader, ForbiddenOnReadOnlyDatabase): + if self.pool: + self.pool.on_write_failure( + address=self.unresolved_address, + database=self.last_database, + ) + raise + except Neo4jError as e: + if self.pool: + self.pool.on_neo4j_error(e, self) + raise + else: + sig_int = ord(summary_signature) + raise BoltProtocolError( + f"Unexpected response message with signature {sig_int:02X}", + self.unresolved_address, + ) + + return len(details), 1 diff --git a/src/neo4j/_sync/io/_common.py b/src/neo4j/_sync/io/_common.py index 09fba0d8..a4870175 100644 --- a/src/neo4j/_sync/io/_common.py +++ b/src/neo4j/_sync/io/_common.py @@ -21,6 +21,7 @@ from ..._async_compat.util import Util from ..._exceptions import SocketDeadlineExceededError +from ...api import Version from ...exceptions import ( Neo4jError, ServiceUnavailable, @@ -29,6 +30,8 @@ ) +GQL_ERROR_AWARE_PROTOCOL = Version(5, 7) + log = logging.getLogger("neo4j.io") @@ -248,7 +251,7 @@ def on_failure(self, metadata): Util.callback(handler, metadata) handler = self.handlers.get("on_summary") Util.callback(handler) - raise Neo4jError.hydrate(**metadata) + raise self._hydrate_error(metadata) def on_ignored(self, metadata=None): """Handle an IGNORED message been received.""" @@ -257,6 +260,12 @@ def on_ignored(self, metadata=None): handler = self.handlers.get("on_summary") Util.callback(handler) + def _hydrate_error(self, metadata): + if self.connection.PROTOCOL_VERSION >= GQL_ERROR_AWARE_PROTOCOL: + return Neo4jError._hydrate_gql(**metadata) + else: + return Neo4jError._hydrate_neo4j(**metadata) + class InitResponse(Response): def on_failure(self, metadata): @@ -271,7 +280,7 @@ def on_failure(self, metadata): "message", "Connection initialisation failed due to an unknown error", ) - raise Neo4jError.hydrate(**metadata) + raise self._hydrate_error(metadata) class LogonResponse(InitResponse): @@ -283,7 +292,7 @@ def on_failure(self, metadata): Util.callback(handler, metadata) handler = self.handlers.get("on_summary") Util.callback(handler) - raise Neo4jError.hydrate(**metadata) + raise self._hydrate_error(metadata) class ResetResponse(Response): diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py index 94fd1d06..1570e745 100644 --- a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -338,6 +338,7 @@ def health_check(connection_, deadline_): or not self.cond.wait(timeout) ): log.debug("[#0000] _: acquisition timed out") + # TODO: 6.0 - change this to be a DriverError (or subclass) raise ClientError( "failed to obtain a connection from the pool within " f"{deadline.original_timeout!r}s (timeout)" @@ -1052,8 +1053,10 @@ def acquire( liveness_check_timeout, ): if access_mode not in {WRITE_ACCESS, READ_ACCESS}: + # TODO: 6.0 - change this to be a ValueError raise ClientError(f"Non valid 'access_mode'; {access_mode}") if not timeout: + # TODO: 6.0 - change this to be a ValueError raise ClientError( f"'timeout' must be a float larger than 0; {timeout}" ) diff --git a/src/neo4j/_sync/work/session.py b/src/neo4j/_sync/work/session.py index 2ae099a9..ee23b792 100644 --- a/src/neo4j/_sync/work/session.py +++ b/src/neo4j/_sync/work/session.py @@ -294,6 +294,7 @@ def run( raise TypeError("query must be a string or a Query instance") if self._transaction: + # TODO: 6.0 - change this to be a TransactionError raise ClientError( "Explicit Transaction must be handled explicitly" ) diff --git a/src/neo4j/_work/summary.py b/src/neo4j/_work/summary.py index f605e3c7..54849832 100644 --- a/src/neo4j/_work/summary.py +++ b/src/neo4j/_work/summary.py @@ -767,7 +767,7 @@ def gql_status(self) -> str: .. note:: This means these codes are not guaranteed to be stable and may - change in future versions. + change in future versions of the driver or the server. """ if hasattr(self, "_gql_status"): return self._gql_status diff --git a/src/neo4j/exceptions.py b/src/neo4j/exceptions.py index a317ea09..07a23f70 100644 --- a/src/neo4j/exceptions.py +++ b/src/neo4j/exceptions.py @@ -60,11 +60,56 @@ from __future__ import annotations import typing as t +from copy import deepcopy as _deepcopy +from enum import Enum as _Enum -from ._meta import deprecated +from ._meta import ( + deprecated, + preview as _preview, +) + + +__all__ = [ + "AuthConfigurationError", + "AuthError", + "BrokenRecordError", + "CertificateConfigurationError", + "ClientError", + "ConfigurationError", + "ConstraintError", + "CypherSyntaxError", + "CypherTypeError", + "DatabaseError", + "DatabaseUnavailable", + "DriverError", + "Forbidden", + "ForbiddenOnReadOnlyDatabase", + "GqlError", + "GqlErrorClassification", + "IncompleteCommit", + "Neo4jError", + "NotALeader", + "ReadServiceUnavailable", + "ResultConsumedError", + "ResultError", + "ResultFailedError", + "ResultNotSingleError", + "RoutingServiceUnavailable", + "ServiceUnavailable", + "SessionError", + "SessionExpired", + "TokenExpired", + "TransactionError", + "TransactionNestingError", + "TransientError", + "UnsupportedServerProduct", + "WriteServiceUnavailable", +] if t.TYPE_CHECKING: + from collections.abc import Mapping + import typing_extensions as te from ._async.work import ( @@ -88,6 +133,7 @@ ] _TResult = t.Union[AsyncResult, Result] _TSession = t.Union[AsyncSession, Session] + _T = t.TypeVar("_T") else: _TTransaction = t.Union[ "AsyncManagedTransaction", @@ -168,61 +214,458 @@ } +_UNKNOWN_NEO4J_CODE: te.Final[str] = "Neo.DatabaseError.General.UnknownError" +# TODO: 6.0 - Make _UNKNOWN_GQL_MESSAGE the default message +_UNKNOWN_MESSAGE: te.Final[str] = "An unknown error occurred" +_UNKNOWN_GQL_STATUS: te.Final[str] = "50N42" +_UNKNOWN_GQL_DESCRIPTION: te.Final[str] = ( + "error: general processing exception - unexpected error" +) +_UNKNOWN_GQL_MESSAGE: te.Final[str] = ( + f"{_UNKNOWN_GQL_STATUS}: " + "Unexpected error has occurred. See debug log for details." +) +_UNKNOWN_GQL_DIAGNOSTIC_RECORD: te.Final[tuple[tuple[str, t.Any], ...]] = ( + ("OPERATION", ""), + ("OPERATION_CODE", "0"), + ("CURRENT_SCHEMA", "/"), +) + + +class GqlErrorClassification(str, _Enum): + """ + Server-side GQL error category. + + Inherits from :class:`str` and :class:`enum.Enum`. + Hence, can also be compared to its string value:: + + >>> GqlErrorClassification.CLIENT_ERROR == "CLIENT_ERROR" + True + >>> GqlErrorClassification.DATABASE_ERROR == "DATABASE_ERROR" + True + >>> GqlErrorClassification.TRANSIENT_ERROR == "TRANSIENT_ERROR" + True + + **This is a preview**. + It might be changed without following the deprecation policy. + See also + https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + + .. seealso:: :attr:`.GqlError.gql_classification` + + .. versionadded:: 5.26 + """ + + CLIENT_ERROR = "CLIENT_ERROR" + DATABASE_ERROR = "DATABASE_ERROR" + TRANSIENT_ERROR = "TRANSIENT_ERROR" + #: Used when the server provides a Classification which the driver is + #: unaware of. + #: This can happen when connecting to a server newer than the driver or + #: before GQL errors were introduced. + UNKNOWN = "UNKNOWN" + + +class GqlError(Exception): + """ + The GQL compliant data of an error. + + This error isn't raised by the driver as is. + Instead, only subclasses are raised. + Further, it is used as the :attr:`__cause__` of GqlError subclasses. + + **This is a preview**. + It might be changed without following the deprecation policy. + See also + https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + + .. versionadded: 5.26 + """ + + _gql_status: str + # TODO: 6.0 - make message always str + _message: str | None + _gql_status_description: str + _gql_raw_classification: str | None + _gql_classification: GqlErrorClassification + _status_diagnostic_record: dict[str, t.Any] # original, internal only + _diagnostic_record: dict[str, t.Any] # copy to be used externally + _gql_cause: GqlError | None + + @staticmethod + def _hydrate_cause(**metadata: t.Any) -> GqlError: + meta_extractor = _MetaExtractor(metadata) + gql_status = meta_extractor.str_value("gql_status") + description = meta_extractor.str_value("description") + message = meta_extractor.str_value("message") + diagnostic_record = meta_extractor.map_value("diagnostic_record") + cause_map = meta_extractor.map_value("cause") + if cause_map is not None: + cause = GqlError._hydrate_cause(**cause_map) + else: + cause = None + inst = GqlError() + inst._init_gql( + gql_status=gql_status, + message=message, + description=description, + diagnostic_record=diagnostic_record, + cause=cause, + ) + return inst + + def _init_gql( + self, + *, + gql_status: str | None = None, + message: str | None = None, + description: str | None = None, + diagnostic_record: dict[str, t.Any] | None = None, + cause: GqlError | None = None, + ) -> None: + if gql_status is None or message is None or description is None: + self._gql_status = _UNKNOWN_GQL_STATUS + self._message = _UNKNOWN_GQL_MESSAGE + self._gql_status_description = _UNKNOWN_GQL_DESCRIPTION + else: + self._gql_status = gql_status + self._message = message + self._gql_status_description = description + if diagnostic_record is not None: + self._status_diagnostic_record = diagnostic_record + self._gql_cause = cause + + def _set_unknown_gql(self): + self._gql_status = _UNKNOWN_GQL_STATUS + self._message = _UNKNOWN_GQL_MESSAGE + self._gql_status_description = _UNKNOWN_GQL_DESCRIPTION + + def __getattribute__(self, item): + if item != "__cause__": + return super().__getattribute__(item) + gql_cause = self._get_attr_or_none("_gql_cause") + if gql_cause is None: + # No GQL cause, no magic needed + return super().__getattribute__(item) + local_cause = self._get_attr_or_none("__cause__") + if local_cause is None: + # We have a GQL cause but no local cause + # => set the GQL cause as the local cause + self.__cause__ = gql_cause + self.__suppress_context__ = True + self._gql_cause = None + return super().__getattribute__(item) + # We have both a GQL cause and a local cause + # => traverse the cause chain and append the local cause. + root = gql_cause + seen_errors = {id(self), id(root)} + while True: + cause = getattr(root, "__cause__", None) + if cause is None: + root.__cause__ = local_cause + root.__suppress_context__ = True + self.__cause__ = gql_cause + self.__suppress_context__ = True + self._gql_cause = None + return gql_cause + root = cause + if id(root) in seen_errors: + # Circular cause chain -> we have no choice but to either + # overwrite the cause or ignore the new one. + return local_cause + seen_errors.add(id(root)) + + def _get_attr_or_none(self, item): + try: + return super().__getattribute__(item) + except AttributeError: + return None + + @property + def _gql_status_no_preview(self) -> str: + if hasattr(self, "_gql_status"): + return self._gql_status + + self._set_unknown_gql() + return self._gql_status + + @property + @_preview("GQLSTATUS support is a preview feature.") + def gql_status(self) -> str: + """ + The GQLSTATUS returned from the server. + + The status code ``50N42`` (unknown error) is a special code that the + driver will use for polyfilling (when connected to an old, + non-GQL-aware server). + Further, it may be used by servers during the transition-phase to + GQLSTATUS-awareness. + + .. note:: + This means that the code ``50N42`` is not guaranteed to be stable + and may change in future versions of the driver or the server. + """ + return self._gql_status_no_preview + + @property + def _message_no_preview(self) -> str | None: + if hasattr(self, "_message"): + return self._message + + self._set_unknown_gql() + return self._message + + @property + @_preview("GQLSTATUS support is a preview feature.") + def message(self) -> str | None: + """ + The error message returned by the server. + + It is a string representation of the error that occurred. + + This message is meant for human consumption and debugging purposes. + Don't rely on it in a programmatic way. + + This value is never :data:`None` unless the subclass in question + states otherwise. + """ + return self._message_no_preview + + @property + def _gql_status_description_no_preview(self) -> str: + if hasattr(self, "_gql_status_description"): + return self._gql_status_description + + self._set_unknown_gql() + return self._gql_status_description + + @property + @_preview("GQLSTATUS support is a preview feature.") + def gql_status_description(self) -> str: + """ + A description of the GQLSTATUS returned from the server. + + It describes the error that occurred in detail. + + This description is meant for human consumption and debugging purposes. + Don't rely on it in a programmatic way. + """ + return self._gql_status_description_no_preview + + @property + def _gql_raw_classification_no_preview(self) -> str | None: + if hasattr(self, "_gql_raw_classification"): + return self._gql_raw_classification + + diag_record = self._get_status_diagnostic_record() + classification = diag_record.get("_classification") + if not isinstance(classification, str): + self._gql_raw_classification = None + else: + self._gql_raw_classification = classification + return self._gql_raw_classification + + @property + @_preview("GQLSTATUS support is a preview feature.") + def gql_raw_classification(self) -> str | None: + """ + Vendor specific classification of the error. + + This is a convenience accessor for ``_classification`` in the + diagnostic record. + :data:`None` is returned if the classification is not available + or not a string. + """ + return self._gql_raw_classification_no_preview + + @property + def _gql_classification_no_preview(self) -> GqlErrorClassification: + if hasattr(self, "_gql_classification"): + return self._gql_classification + + classification = self._gql_raw_classification_no_preview + if not ( + isinstance(classification, str) + and classification + in t.cast(t.Iterable[str], iter(GqlErrorClassification)) + ): + self._gql_classification = GqlErrorClassification.UNKNOWN + else: + self._gql_classification = GqlErrorClassification(classification) + return self._gql_classification + + @property + @_preview("GQLSTATUS support is a preview feature.") + def gql_classification(self) -> GqlErrorClassification: + return self._gql_classification_no_preview + + def _get_status_diagnostic_record(self) -> dict[str, t.Any]: + if hasattr(self, "_status_diagnostic_record"): + return self._status_diagnostic_record + + self._status_diagnostic_record = dict(_UNKNOWN_GQL_DIAGNOSTIC_RECORD) + return self._status_diagnostic_record + + @property + def _diagnostic_record_no_preview(self) -> Mapping[str, t.Any]: + if hasattr(self, "_diagnostic_record"): + return self._diagnostic_record + + self._diagnostic_record = _deepcopy( + self._get_status_diagnostic_record() + ) + return self._diagnostic_record + + @property + @_preview("GQLSTATUS support is a preview feature.") + def diagnostic_record(self) -> Mapping[str, t.Any]: + return self._diagnostic_record_no_preview + + def __str__(self): + return ( + f"{{gql_status: {self._gql_status_no_preview}}} " + f"{{gql_status_description: " + f"{self._gql_status_description_no_preview}}} " + f"{{message: {self._message_no_preview}}} " + f"{{diagnostic_record: {self._diagnostic_record_no_preview}}} " + f"{{raw_classification: " + f"{self._gql_raw_classification_no_preview}}}" + ) + + # Neo4jError -class Neo4jError(Exception): +class Neo4jError(GqlError): """Raised when the Cypher engine returns an error to the client.""" - #: (str or None) The error message returned by the server. - message = None - #: (str or None) The error code returned by the server. - #: There are many Neo4j status codes, see - #: `status codes `_. - code = None - classification = None - category = None - title = None + _neo4j_code: str | None + _classification: str | None + _category: str | None + _title: str | None #: (dict) Any additional information returned by the server. - metadata = None + _metadata: dict[str, t.Any] | None _retryable = False + def __init__(self, *args) -> None: + Exception.__init__(self, *args) + self._neo4j_code = None + self._classification = None + self._category = None + self._title = None + self._metadata = None + self._message = None + + # TODO: 6.0 - do this instead to get rid of all optional attributes + # self._neo4j_code = _UNKNOWN_NEO4J_CODE + # _, self._classification, self._category, self._title = ( + # self._neo4j_code.split(".") + # ) + # self._metadata = {} + # self._init_gql() + + # TODO: 6.0 - Remove this alias @classmethod + @deprecated( + "Neo4jError.hydrate is deprecated and will be removed in a future " + "version. It is an internal method and not meant for external use." + ) def hydrate( cls, - message: str | None = None, code: str | None = None, + message: str | None = None, **metadata: t.Any, ) -> Neo4jError: - message = message or "An unknown error occurred" - code = code or "Neo.DatabaseError.General.UnknownError" + # backward compatibility: make falsy values None + code = code or None + message = message or None + return cls._hydrate_neo4j(code=code, message=message, **metadata) + + @classmethod + def _hydrate_neo4j(cls, **metadata: t.Any) -> Neo4jError: + meta_extractor = _MetaExtractor(metadata) + code = meta_extractor.str_value("code") or _UNKNOWN_NEO4J_CODE + message = meta_extractor.str_value("message") or _UNKNOWN_MESSAGE + inst = cls._basic_hydrate( + neo4j_code=code, + message=message, + ) + inst._init_gql( + gql_status=_UNKNOWN_GQL_STATUS, + message=message, + description=f"{_UNKNOWN_GQL_DESCRIPTION}. {message}", + ) + inst._metadata = meta_extractor.rest() + return inst + + @classmethod + def _hydrate_gql(cls, **metadata: t.Any) -> Neo4jError: + meta_extractor = _MetaExtractor(metadata) + gql_status = meta_extractor.str_value("gql_status") + status_description = meta_extractor.str_value("description") + message = meta_extractor.str_value("message") + if gql_status is None or status_description is None or message is None: + gql_status = _UNKNOWN_GQL_STATUS + # TODO: 6.0 - Make this fall back to _UNKNOWN_GQL_MESSAGE + message = _UNKNOWN_MESSAGE + status_description = _UNKNOWN_GQL_DESCRIPTION + neo4j_code = meta_extractor.str_value( + "neo4j_code", + _UNKNOWN_NEO4J_CODE, + ) + diagnostic_record = meta_extractor.map_value("diagnostic_record") + cause_map = meta_extractor.map_value("cause") + if cause_map is not None: + cause = cls._hydrate_cause(**cause_map) + else: + cause = None + + inst = cls._basic_hydrate( + neo4j_code=neo4j_code, + message=message, + ) + inst._init_gql( + gql_status=gql_status, + message=message, + description=status_description, + diagnostic_record=diagnostic_record, + cause=cause, + ) + inst._metadata = meta_extractor.rest() + + return inst + + @classmethod + def _basic_hydrate(cls, *, neo4j_code: str, message: str) -> Neo4jError: try: - _, classification, category, title = code.split(".") + _, classification, category, title = neo4j_code.split(".") except ValueError: classification = CLASSIFICATION_DATABASE category = "General" title = "UnknownError" else: classification_override, code_override = ERROR_REWRITE_MAP.get( - code, (None, None) + neo4j_code, (None, None) ) if classification_override is not None: classification = classification_override if code_override is not None: - code = code_override + neo4j_code = code_override - error_class = cls._extract_error_class(classification, code) + error_class: type[Neo4jError] = cls._extract_error_class( + classification, neo4j_code + ) inst = error_class(message) - inst.message = message - inst.code = code - inst.classification = classification - inst.category = category - inst.title = title - inst.metadata = metadata + inst._neo4j_code = neo4j_code + inst._classification = classification + inst._category = category + inst._title = title + inst._message = message + return inst @classmethod - def _extract_error_class(cls, classification, code): + def _extract_error_class(cls, classification, code) -> type[Neo4jError]: if classification == CLASSIFICATION_CLIENT: try: return client_errors[code] @@ -241,6 +684,76 @@ def _extract_error_class(cls, classification, code): else: return cls + @property + def message(self) -> str | None: + """ + The error message returned by the server. + + This value is only :data:`None` for locally created errors. + """ + return self._message + + @message.setter + @deprecated("Altering the message of a Neo4jError is deprecated.") + def message(self, value: str) -> None: + self._message = value + + @property + def code(self) -> str | None: + """ + The neo4j error code returned by the server. + + For example, "Neo.ClientError.Security.AuthorizationExpired". + This value is only :data:`None` for locally created errors. + """ + return self._neo4j_code + + # TODO: 6.0 - Remove this and all other deprecated setters + @code.setter + @deprecated("Altering the code of a Neo4jError is deprecated.") + def code(self, value: str) -> None: + self._neo4j_code = value + + @property + def classification(self) -> str | None: + # Undocumented, will likely be removed with support for neo4j codes + return self._classification + + @classification.setter + @deprecated("Altering the classification of Neo4jError is deprecated.") + def classification(self, value: str) -> None: + self._classification = value + + @property + def category(self) -> str | None: + # Undocumented, will likely be removed with support for neo4j codes + return self._category + + @category.setter + @deprecated("Altering the category of Neo4jError is deprecated.") + def category(self, value: str) -> None: + self._category = value + + @property + def title(self) -> str | None: + # Undocumented, will likely be removed with support for neo4j codes + return self._title + + @title.setter + @deprecated("Altering the title of Neo4jError is deprecated.") + def title(self, value: str) -> None: + self._title = value + + @property + def metadata(self) -> dict[str, t.Any] | None: + # Undocumented, might be useful for debugging + return self._metadata + + @metadata.setter + @deprecated("Altering the metadata of Neo4jError is deprecated.") + def metadata(self, value: dict[str, t.Any]) -> None: + self._metadata = value + # TODO: 6.0 - Remove this alias @deprecated( "Neo4jError.is_retriable is deprecated and will be removed in a " @@ -277,7 +790,9 @@ def is_retryable(self) -> bool: return self._retryable def _unauthenticates_all_connections(self) -> bool: - return self.code == "Neo.ClientError.Security.AuthorizationExpired" + return ( + self._neo4j_code == "Neo.ClientError.Security.AuthorizationExpired" + ) # TODO: 6.0 - Remove this alias invalidates_all_connections = deprecated( @@ -289,9 +804,10 @@ def _unauthenticates_all_connections(self) -> bool: def _is_fatal_during_discovery(self) -> bool: # checks if the code is an error that is caused by the client. In this # case the driver should fail fast during discovery. - if not isinstance(self.code, str): + code = self._neo4j_code + if not isinstance(code, str): return False - if self.code in { + if code in { "Neo.ClientError.Database.DatabaseNotFound", "Neo.ClientError.Transaction.InvalidBookmark", "Neo.ClientError.Transaction.InvalidBookmarkMixture", @@ -301,14 +817,14 @@ def _is_fatal_during_discovery(self) -> bool: }: return True return ( - self.code.startswith("Neo.ClientError.Security.") - and self.code != "Neo.ClientError.Security.AuthorizationExpired" + code.startswith("Neo.ClientError.Security.") + and code != "Neo.ClientError.Security.AuthorizationExpired" ) def _has_security_code(self) -> bool: - if self.code is None: + if self._neo4j_code is None: return False - return self.code.startswith("Neo.ClientError.Security.") + return self._neo4j_code.startswith("Neo.ClientError.Security.") # TODO: 6.0 - Remove this alias is_fatal_during_discovery = deprecated( @@ -318,9 +834,57 @@ def _has_security_code(self) -> bool: )(_is_fatal_during_discovery) def __str__(self): - if self.code or self.message: - return f"{{code: {self.code}}} {{message: {self.message}}}" - return super().__str__() + code = self._neo4j_code + message = self._message + if code or message: + return f"{{code: {code}}} {{message: {message}}}" + # TODO: 6.0 - Use gql status and status_description instead + # something like: + # return ( + # f"{{gql_status: {self.gql_status}}} " + # f"{{neo4j_code: {self.neo4j_code}}} " + # f"{{gql_status_description: {self.gql_status_description}}} " + # f"{{diagnostic_record: {self.diagnostic_record}}}" + # ) + return Exception.__str__(self) + + +class _MetaExtractor: + def __init__(self, metadata: dict[str, t.Any]): + self._metadata = metadata + + def rest(self) -> dict[str, t.Any]: + return self._metadata + + @t.overload + def str_value(self, key: str) -> str | None: ... + + @t.overload + def str_value(self, key: str, default: _T) -> str | _T: ... + + def str_value( + self, key: str, default: _T | None = None + ) -> str | _T | None: + res = self._metadata.pop(key, default) + if not isinstance(res, str): + res = default + return res + + @t.overload + def map_value(self, key: str) -> dict[str, t.Any] | None: ... + + @t.overload + def map_value(self, key: str, default: _T) -> dict[str, t.Any] | _T: ... + + def map_value( + self, key: str, default: _T | None = None + ) -> dict[str, t.Any] | _T | None: + res = self._metadata.pop(key, default) + if not ( + isinstance(res, dict) and all(isinstance(k, str) for k in res) + ): + res = default + return res # Neo4jError > ClientError @@ -433,7 +997,7 @@ class ForbiddenOnReadOnlyDatabase(TransientError): # DriverError -class DriverError(Exception): +class DriverError(GqlError): """Raised when the Driver raises an error.""" def is_retryable(self) -> bool: @@ -451,6 +1015,9 @@ def is_retryable(self) -> bool: """ return False + def __str__(self): + return Exception.__str__(self) + # DriverError > SessionError class SessionError(DriverError): @@ -531,6 +1098,13 @@ class SessionExpired(DriverError): its original parameters. """ + def __init__(self, *args): + super().__init__(*args) + self._init_gql( + gql_status="08000", + description="error: connection exception", + ) + def is_retryable(self) -> bool: return True @@ -544,6 +1118,13 @@ class ServiceUnavailable(DriverError): failure of a database service that the driver is unable to route around. """ + def __init__(self, *args): + super().__init__(*args) + self._init_gql( + gql_status="08000", + description="error: connection exception", + ) + def is_retryable(self) -> bool: return True @@ -574,6 +1155,16 @@ class IncompleteCommit(ServiceUnavailable): successfully or not. """ + def __init__(self, *args): + super().__init__(*args) + self._init_gql( + gql_status="08007", + description=( + "error: connection exception - " + "transaction resolution unknown" + ), + ) + def is_retryable(self) -> bool: return False diff --git a/testkitbackend/_async/backend.py b/testkitbackend/_async/backend.py index a6618495..76f7fb42 100644 --- a/testkitbackend/_async/backend.py +++ b/testkitbackend/_async/backend.py @@ -34,6 +34,7 @@ UnsupportedServerProduct, ) +from .. import totestkit from .._driver_logger import ( buffer_handler, log, @@ -124,6 +125,17 @@ async def process_request(self): @staticmethod def _exc_stems_from_driver(exc): + if isinstance( + exc, + ( + Neo4jError, + DriverError, + UnsupportedServerProduct, + BoltError, + MarkdAsDriverError, + ), + ): + return True stack = traceback.extract_tb(exc.__traceback__) for frame in stack[-1:1:-1]: p = Path(frame.filename) @@ -134,42 +146,37 @@ def _exc_stems_from_driver(exc): return None @staticmethod - def _exc_msg(exc, max_depth=10): - if isinstance(exc, Neo4jError) and exc.message is not None: - return str(exc.message) - - depth = 0 - res = str(exc) - while getattr(exc, "__cause__", None) is not None: - depth += 1 - if depth >= max_depth: - break - res += f"\nCaused by: {exc.__cause__!r}" - exc = exc.__cause__ - return res - - async def write_driver_exc(self, exc): - log.debug(traceback.format_exc()) + def _get_tb(exc): + return "".join( + traceback.format_exception( + type(exc), exc, getattr(exc, "__traceback__", None) + ) + ) + + def _serialize_driver_exc(self, exc): + log.debug(exc.args) + log.debug(self._get_tb(exc)) key = self.next_key() self.errors[key] = exc - payload = {"id": key, "msg": ""} + return totestkit.driver_exc(exc, id_=key) - if isinstance(exc, MarkdAsDriverError): - wrapped_exc = exc.wrapped_exc - payload["errorType"] = str(type(wrapped_exc)) - if wrapped_exc.args: - payload["msg"] = self._exc_msg(wrapped_exc.args[0]) - payload["retryable"] = False - else: - payload["errorType"] = str(type(exc)) - payload["msg"] = self._exc_msg(exc) - if isinstance(exc, Neo4jError): - payload["code"] = exc.code - payload["retryable"] = getattr(exc, "is_retryable", bool)() + @staticmethod + def _serialize_backend_error(exc): + tb = AsyncBackend._get_tb(exc) + log.error(tb) + return {"name": "BackendError", "data": {"msg": tb}} - await self.send_response("DriverError", payload) + def _serialize_exc(self, exc): + try: + if isinstance(exc, requests.FrontendError): + return {"name": "FrontendError", "data": {"msg": str(exc)}} + if self._exc_stems_from_driver(exc): + return self._serialize_driver_exc(exc) + except Exception as e: + return self._serialize_backend_error(e) + return self._serialize_backend_error(exc) async def _process(self, request): # Process a received request. @@ -190,23 +197,9 @@ async def _process(self, request): f"Backend does not support some properties of the {name} " f"request: {', '.join(unused_keys)}" ) - except ( - Neo4jError, - DriverError, - UnsupportedServerProduct, - BoltError, - MarkdAsDriverError, - ) as e: - await self.write_driver_exc(e) - except requests.FrontendError as e: - await self.send_response("FrontendError", {"msg": str(e)}) except Exception as e: - if self._exc_stems_from_driver(e): - await self.write_driver_exc(e) - else: - tb = traceback.format_exc() - log.error(tb) - await self.send_response("BackendError", {"msg": tb}) + data = self._serialize_exc(e) + await self.send_response(data["name"], data["data"]) async def send_response(self, name, data): """Send a response to backend.""" diff --git a/testkitbackend/_sync/backend.py b/testkitbackend/_sync/backend.py index a3c23891..a5464703 100644 --- a/testkitbackend/_sync/backend.py +++ b/testkitbackend/_sync/backend.py @@ -34,6 +34,7 @@ UnsupportedServerProduct, ) +from .. import totestkit from .._driver_logger import ( buffer_handler, log, @@ -124,6 +125,17 @@ def process_request(self): @staticmethod def _exc_stems_from_driver(exc): + if isinstance( + exc, + ( + Neo4jError, + DriverError, + UnsupportedServerProduct, + BoltError, + MarkdAsDriverError, + ), + ): + return True stack = traceback.extract_tb(exc.__traceback__) for frame in stack[-1:1:-1]: p = Path(frame.filename) @@ -134,42 +146,37 @@ def _exc_stems_from_driver(exc): return None @staticmethod - def _exc_msg(exc, max_depth=10): - if isinstance(exc, Neo4jError) and exc.message is not None: - return str(exc.message) - - depth = 0 - res = str(exc) - while getattr(exc, "__cause__", None) is not None: - depth += 1 - if depth >= max_depth: - break - res += f"\nCaused by: {exc.__cause__!r}" - exc = exc.__cause__ - return res - - def write_driver_exc(self, exc): - log.debug(traceback.format_exc()) + def _get_tb(exc): + return "".join( + traceback.format_exception( + type(exc), exc, getattr(exc, "__traceback__", None) + ) + ) + + def _serialize_driver_exc(self, exc): + log.debug(exc.args) + log.debug(self._get_tb(exc)) key = self.next_key() self.errors[key] = exc - payload = {"id": key, "msg": ""} + return totestkit.driver_exc(exc, id_=key) - if isinstance(exc, MarkdAsDriverError): - wrapped_exc = exc.wrapped_exc - payload["errorType"] = str(type(wrapped_exc)) - if wrapped_exc.args: - payload["msg"] = self._exc_msg(wrapped_exc.args[0]) - payload["retryable"] = False - else: - payload["errorType"] = str(type(exc)) - payload["msg"] = self._exc_msg(exc) - if isinstance(exc, Neo4jError): - payload["code"] = exc.code - payload["retryable"] = getattr(exc, "is_retryable", bool)() + @staticmethod + def _serialize_backend_error(exc): + tb = Backend._get_tb(exc) + log.error(tb) + return {"name": "BackendError", "data": {"msg": tb}} - self.send_response("DriverError", payload) + def _serialize_exc(self, exc): + try: + if isinstance(exc, requests.FrontendError): + return {"name": "FrontendError", "data": {"msg": str(exc)}} + if self._exc_stems_from_driver(exc): + return self._serialize_driver_exc(exc) + except Exception as e: + return self._serialize_backend_error(e) + return self._serialize_backend_error(exc) def _process(self, request): # Process a received request. @@ -190,23 +197,9 @@ def _process(self, request): f"Backend does not support some properties of the {name} " f"request: {', '.join(unused_keys)}" ) - except ( - Neo4jError, - DriverError, - UnsupportedServerProduct, - BoltError, - MarkdAsDriverError, - ) as e: - self.write_driver_exc(e) - except requests.FrontendError as e: - self.send_response("FrontendError", {"msg": str(e)}) except Exception as e: - if self._exc_stems_from_driver(e): - self.write_driver_exc(e) - else: - tb = traceback.format_exc() - log.error(tb) - self.send_response("BackendError", {"msg": tb}) + data = self._serialize_exc(e) + self.send_response(data["name"], data["data"]) def send_response(self, name, data): """Send a response to backend.""" diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index e4d8b14b..bca7f0ca 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -58,6 +58,7 @@ "Feature:Bolt:5.4": true, "Feature:Bolt:5.5": true, "Feature:Bolt:5.6": true, + "Feature:Bolt:5.7": true, "Feature:Bolt:Patch:UTC": true, "Feature:Impersonation": true, "Feature:TLS:1.1": "Driver blocks TLS 1.1 for security reasons.", diff --git a/testkitbackend/totestkit.py b/testkitbackend/totestkit.py index f6960ce6..eac8aec3 100644 --- a/testkitbackend/totestkit.py +++ b/testkitbackend/totestkit.py @@ -19,6 +19,11 @@ import math import neo4j +from neo4j.exceptions import ( + GqlError, + Neo4jError, + ResultFailedError, +) from neo4j.graph import ( Node, Path, @@ -36,6 +41,7 @@ ) from ._warning_check import warning_check +from .exceptions import MarkdAsDriverError def record(rec): @@ -293,3 +299,101 @@ def to(name, val): def auth_token(auth): return {"name": "AuthorizationToken", "data": vars(auth)} + + +def driver_exc(exc, id_=None): + payload = {} + if id_ is not None: + payload["id"] = id_ + payload["retryable"] = getattr(exc, "is_retryable", bool)() + if isinstance(exc, MarkdAsDriverError): + wrapped_exc = exc.wrapped_exc + payload["errorType"] = str(type(wrapped_exc)) + if wrapped_exc.args: + payload["msg"] = _exc_msg(wrapped_exc.args[0]) + else: + payload["errorType"] = str(type(exc)) + payload["msg"] = _exc_msg(exc) + if isinstance(exc, Neo4jError): + payload["code"] = exc.code + if isinstance(exc, GqlError): + with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): + payload["gqlStatus"] = exc.gql_status + with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): + payload["statusDescription"] = exc.gql_status_description + with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): + payload["rawClassification"] = exc.gql_raw_classification + with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): + payload["classification"] = exc.gql_classification + with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): + payload["diagnosticRecord"] = { + k: field(v) for k, v in exc.diagnostic_record.items() + } + cause = driver_exc_cause(getattr(exc, "__cause__", None)) + if cause is not None: + payload["cause"] = cause + + return {"name": "DriverError", "data": payload} + + +def _exc_msg(exc, max_depth=10): + if isinstance(exc, Neo4jError) and exc.message is not None: + return str(exc.message) + + depth = 0 + if isinstance(exc, GqlError): + if isinstance(exc, Neo4jError): + res = str(exc.message) if exc.message is not None else str(exc) + else: + with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): + msg = exc.message + res = f"{msg} - {exc!s}" if exc.args else msg + else: + res = str(exc) + while getattr(exc, "__cause__", None) is not None: + if ( + # Not including GqlError in the chain as they will be serialized + # separately in the `cause` field. + isinstance(exc.__cause__, GqlError) + # Special case for ResultFailedError: + # Always serialize the cause in the message to please TestKit. + # Else, the cause's class name will get lost (can't be serialized + # as a field in of an error cause). + and not isinstance(exc, ResultFailedError) + ): + break + depth += 1 + if depth >= max_depth: + break + res += f"\nCaused by: {exc.__cause__!r}" + exc = exc.__cause__ + return res + + +def driver_exc_cause(exc, max_depth=10): + if exc is None: + return None + if max_depth <= 0: + return None + if not isinstance(exc, GqlError): + return driver_exc_cause( + getattr(exc, "__cause__", None), max_depth=max_depth - 1 + ) + payload = {"msg": _exc_msg(exc)} + with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): + payload["gqlStatus"] = exc.gql_status + with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): + payload["statusDescription"] = exc.gql_status_description + with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): + payload["diagnosticRecord"] = { + k: field(v) for k, v in exc.diagnostic_record.items() + } + with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): + payload["classification"] = exc.gql_classification + with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): + payload["rawClassification"] = exc.gql_raw_classification + cause = getattr(exc, "__cause__", None) + if cause is not None: + payload["cause"] = driver_exc_cause(cause, max_depth=max_depth - 1) + + return {"name": "GqlError", "data": payload} diff --git a/tests/iter_util.py b/tests/iter_util.py index b63f4766..69f5ba00 100644 --- a/tests/iter_util.py +++ b/tests/iter_util.py @@ -29,7 +29,8 @@ def powerset( iterable: t.Iterable[_T], - limit: int | None = None, + lower_limit: int | None = None, + upper_limit: int | None = None, ) -> t.Iterable[tuple[_T, ...]]: """ Build the powerset of an iterable. @@ -39,12 +40,19 @@ def powerset( >>> tuple(powerset([1, 2, 3])) ((), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)) - >>> tuple(powerset([1, 2, 3], limit=2)) + >>> tuple(powerset([1, 2, 3], upper_limit=2)) ((), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3)) + >>> tuple(powerset([1, 2, 3], lower_limit=2)) + ((1, 2), (1, 3), (2, 3), (1, 2, 3)) + :return: The powerset of the iterable. """ s = list(iterable) - if limit is None: - limit = len(s) - return chain.from_iterable(combinations(s, r) for r in range(limit + 1)) + if upper_limit is None: + upper_limit = len(s) + if lower_limit is None: + lower_limit = 0 + return chain.from_iterable( + combinations(s, r) for r in range(lower_limit, upper_limit + 1) + ) diff --git a/tests/unit/async_/fixtures/fake_connection.py b/tests/unit/async_/fixtures/fake_connection.py index 06d0270f..9bf96779 100644 --- a/tests/unit/async_/fixtures/fake_connection.py +++ b/tests/unit/async_/fixtures/fake_connection.py @@ -206,7 +206,7 @@ async def callback(): cb_args = default_cb_args res = cb(*cb_args) if cb_name == "on_failure": - error = Neo4jError.hydrate(**cb_args[0]) + error = Neo4jError._hydrate_gql(**cb_args[0]) # suppress in case the callback is not async with suppress(TypeError): await res diff --git a/tests/unit/async_/io/test_class_bolt.py b/tests/unit/async_/io/test_class_bolt.py index 5bbc50e8..b0ddbc96 100644 --- a/tests/unit/async_/io/test_class_bolt.py +++ b/tests/unit/async_/io/test_class_bolt.py @@ -38,7 +38,7 @@ def test_class_method_protocol_handlers(): expected_handlers = { (3, 0), (4, 1), (4, 2), (4, 3), (4, 4), - (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), + (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 7), } # fmt: on @@ -68,7 +68,8 @@ def test_class_method_protocol_handlers(): ((5, 4), 1), ((5, 5), 1), ((5, 6), 1), - ((5, 7), 0), + ((5, 7), 1), + ((5, 8), 0), ((6, 0), 0), ], ) @@ -91,7 +92,7 @@ def test_class_method_get_handshake(): handshake = AsyncBolt.get_handshake() assert ( handshake - == b"\x00\x06\x06\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" + == b"\x00\x07\x07\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" ) @@ -141,6 +142,7 @@ async def test_cancel_hello_in_open(mocker, none_auth): ((5, 4), "neo4j._async.io._bolt5.AsyncBolt5x4"), ((5, 5), "neo4j._async.io._bolt5.AsyncBolt5x5"), ((5, 6), "neo4j._async.io._bolt5.AsyncBolt5x6"), + ((5, 7), "neo4j._async.io._bolt5.AsyncBolt5x7"), ), ) @mark_async_test @@ -179,7 +181,7 @@ async def test_version_negotiation( (2, 0), (4, 0), (3, 1), - (5, 7), + (5, 8), (6, 0), ), ) @@ -187,7 +189,7 @@ async def test_version_negotiation( async def test_failing_version_negotiation(mocker, bolt_version, none_auth): supported_protocols = ( "('3.0', '4.1', '4.2', '4.3', '4.4', " - "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5', '5.6')" + "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5', '5.6', '5.7')" ) address = ("localhost", 7687) diff --git a/tests/unit/async_/io/test_class_bolt3.py b/tests/unit/async_/io/test_class_bolt3.py index b6ab7003..e2f56ff9 100644 --- a/tests/unit/async_/io/test_class_bolt3.py +++ b/tests/unit/async_/io/test_class_bolt3.py @@ -532,7 +532,7 @@ def raises_if_db(db): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/async_/io/test_class_bolt4x0.py b/tests/unit/async_/io/test_class_bolt4x0.py index e981bee1..fa555fd1 100644 --- a/tests/unit/async_/io/test_class_bolt4x0.py +++ b/tests/unit/async_/io/test_class_bolt4x0.py @@ -623,7 +623,7 @@ async def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/async_/io/test_class_bolt4x1.py b/tests/unit/async_/io/test_class_bolt4x1.py index 870cd7fb..e7ca17e0 100644 --- a/tests/unit/async_/io/test_class_bolt4x1.py +++ b/tests/unit/async_/io/test_class_bolt4x1.py @@ -645,7 +645,7 @@ async def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/async_/io/test_class_bolt4x2.py b/tests/unit/async_/io/test_class_bolt4x2.py index e7bf1f0a..bffb4424 100644 --- a/tests/unit/async_/io/test_class_bolt4x2.py +++ b/tests/unit/async_/io/test_class_bolt4x2.py @@ -645,7 +645,7 @@ async def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/async_/io/test_class_bolt4x3.py b/tests/unit/async_/io/test_class_bolt4x3.py index e3ca11d2..1f249feb 100644 --- a/tests/unit/async_/io/test_class_bolt4x3.py +++ b/tests/unit/async_/io/test_class_bolt4x3.py @@ -674,7 +674,7 @@ async def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/async_/io/test_class_bolt4x4.py b/tests/unit/async_/io/test_class_bolt4x4.py index a60fc2c8..695ac7c9 100644 --- a/tests/unit/async_/io/test_class_bolt4x4.py +++ b/tests/unit/async_/io/test_class_bolt4x4.py @@ -634,7 +634,7 @@ async def test_tx_timeout( {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/async_/io/test_class_bolt5x0.py b/tests/unit/async_/io/test_class_bolt5x0.py index 90e4a8fe..d1f09dcc 100644 --- a/tests/unit/async_/io/test_class_bolt5x0.py +++ b/tests/unit/async_/io/test_class_bolt5x0.py @@ -698,7 +698,7 @@ async def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/async_/io/test_class_bolt5x1.py b/tests/unit/async_/io/test_class_bolt5x1.py index 5118dcc4..003263aa 100644 --- a/tests/unit/async_/io/test_class_bolt5x1.py +++ b/tests/unit/async_/io/test_class_bolt5x1.py @@ -752,7 +752,7 @@ async def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/async_/io/test_class_bolt5x2.py b/tests/unit/async_/io/test_class_bolt5x2.py index 1d7bd475..345c9a52 100644 --- a/tests/unit/async_/io/test_class_bolt5x2.py +++ b/tests/unit/async_/io/test_class_bolt5x2.py @@ -789,7 +789,7 @@ async def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/async_/io/test_class_bolt5x3.py b/tests/unit/async_/io/test_class_bolt5x3.py index d0d8ee4d..c70a3df4 100644 --- a/tests/unit/async_/io/test_class_bolt5x3.py +++ b/tests/unit/async_/io/test_class_bolt5x3.py @@ -676,7 +676,7 @@ async def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/async_/io/test_class_bolt5x4.py b/tests/unit/async_/io/test_class_bolt5x4.py index 72e07e9e..7ff21e09 100644 --- a/tests/unit/async_/io/test_class_bolt5x4.py +++ b/tests/unit/async_/io/test_class_bolt5x4.py @@ -681,7 +681,7 @@ async def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/async_/io/test_class_bolt5x5.py b/tests/unit/async_/io/test_class_bolt5x5.py index a7a6000d..77d748de 100644 --- a/tests/unit/async_/io/test_class_bolt5x5.py +++ b/tests/unit/async_/io/test_class_bolt5x5.py @@ -690,7 +690,7 @@ async def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/async_/io/test_class_bolt5x6.py b/tests/unit/async_/io/test_class_bolt5x6.py index 067bd47d..a5106572 100644 --- a/tests/unit/async_/io/test_class_bolt5x6.py +++ b/tests/unit/async_/io/test_class_bolt5x6.py @@ -690,7 +690,7 @@ async def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/async_/io/test_class_bolt5x7.py b/tests/unit/async_/io/test_class_bolt5x7.py new file mode 100644 index 00000000..97a8b4ea --- /dev/null +++ b/tests/unit/async_/io/test_class_bolt5x7.py @@ -0,0 +1,850 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import itertools +import logging + +import pytest + +import neo4j +from neo4j._api import TelemetryAPI +from neo4j._async.config import AsyncPoolConfig +from neo4j._async.io._bolt5 import AsyncBolt5x7 +from neo4j._meta import ( + BOLT_AGENT_DICT, + USER_AGENT, +) +from neo4j.exceptions import Neo4jError + +from ...._async_compat import mark_async_test +from ....iter_util import powerset + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 0 + connection = AsyncBolt5x7( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is True + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = -1 + connection = AsyncBolt5x7( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 999999999 + connection = AsyncBolt5x7( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize( + ("args", "kwargs", "expected_fields"), + ( + (("", {}), {"db": "something"}, ({"db": "something"},)), + (("", {}), {"imp_user": "imposter"}, ({"imp_user": "imposter"},)), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ({"db": "something", "imp_user": "imposter"},), + ), + ), +) +@mark_async_test +async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x7.UNPACKER_CLS) + connection = AsyncBolt5x7( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.begin(*args, **kwargs) + await connection.send_all() + tag, is_fields = await socket.pop_message() + assert tag == b"\x11" + assert tuple(is_fields) == expected_fields + + +@pytest.mark.parametrize( + ("args", "kwargs", "expected_fields"), + ( + (("", {}), {"db": "something"}, ("", {}, {"db": "something"})), + ( + ("", {}), + {"imp_user": "imposter"}, + ("", {}, {"imp_user": "imposter"}), + ), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ("", {}, {"db": "something", "imp_user": "imposter"}), + ), + ), +) +@mark_async_test +async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x7.UNPACKER_CLS) + connection = AsyncBolt5x7( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.run(*args, **kwargs) + await connection.send_all() + tag, is_fields = await socket.pop_message() + assert tag == b"\x10" + assert tuple(is_fields) == expected_fields + + +@mark_async_test +async def test_n_extra_in_discard(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x7.UNPACKER_CLS) + connection = AsyncBolt5x7( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.discard(n=666) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == {"n": 666} + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (666, {"n": -1, "qid": 666}), + (-1, {"n": -1}), + ], +) +@mark_async_test +async def test_qid_extra_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x7.UNPACKER_CLS) + connection = AsyncBolt5x7( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.discard(qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (777, {"n": 666, "qid": 777}), + (-1, {"n": 666}), + ], +) +@mark_async_test +async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x7.UNPACKER_CLS) + connection = AsyncBolt5x7( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.discard(n=666, qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (666, {"n": 666}), + (-1, {"n": -1}), + ], +) +@mark_async_test +async def test_n_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x7.UNPACKER_CLS) + connection = AsyncBolt5x7( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.pull(n=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (777, {"n": -1, "qid": 777}), + (-1, {"n": -1}), + ], +) +@mark_async_test +async def test_qid_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x7.UNPACKER_CLS) + connection = AsyncBolt5x7( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.pull(qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == expected + + +@mark_async_test +async def test_n_and_qid_extras_in_pull(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x7.UNPACKER_CLS) + connection = AsyncBolt5x7( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.pull(n=666, qid=777) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == {"n": 666, "qid": 777} + + +@mark_async_test +async def test_hello_passes_routing_metadata(fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x7.PACKER_CLS, + unpacker_cls=AsyncBolt5x7.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.4.0"}) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x7( + address, + sockets.client, + AsyncPoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"}, + ) + await connection.hello() + tag, fields = await sockets.server.pop_message() + assert tag == b"\x01" + assert len(fields) == 1 + assert fields[0]["routing"] == {"foo": "bar"} + + +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_async_test +async def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x7.UNPACKER_CLS) + connection = AsyncBolt5x7( + address, + socket, + AsyncPoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled, + ) + if serv_enabled: + connection.configuration_hints["telemetry.enabled"] = True + connection.telemetry(api) + await connection.send_all() + + if serv_enabled and not driver_disabled: + tag, fields = await socket.pop_message() + assert tag == b"\x54" + assert fields == [int(api)] + else: + with pytest.raises(OSError): + await socket.pop_message() + + +@pytest.mark.parametrize( + ("hints", "valid"), + ( + ({"connection.recv_timeout_seconds": 1}, True), + ({"connection.recv_timeout_seconds": 42}, True), + ({}, True), + ({"whatever_this_is": "ignore me!"}, True), + ({"connection.recv_timeout_seconds": -1}, False), + ({"connection.recv_timeout_seconds": 0}, False), + ({"connection.recv_timeout_seconds": 2.5}, False), + ({"connection.recv_timeout_seconds": None}, False), + ({"connection.recv_timeout_seconds": False}, False), + ({"connection.recv_timeout_seconds": "1"}, False), + ), +) +@mark_async_test +async def test_hint_recv_timeout_seconds( + fake_socket_pair, hints, valid, caplog, mocker +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x7.PACKER_CLS, + unpacker_cls=AsyncBolt5x7.UNPACKER_CLS, + ) + sockets.client.settimeout = mocker.Mock() + await sockets.server.send_message( + b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} + ) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x7( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime + ) + with caplog.at_level(logging.INFO): + await connection.hello() + if valid: + if "connection.recv_timeout_seconds" in hints: + sockets.client.settimeout.assert_called_once_with( + hints["connection.recv_timeout_seconds"] + ) + else: + sockets.client.settimeout.assert_not_called() + assert not any( + "recv_timeout_seconds" in msg and "invalid" in msg + for msg in caplog.messages + ) + else: + sockets.client.settimeout.assert_not_called() + assert any( + repr(hints["connection.recv_timeout_seconds"]) in msg + and "recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages + ) + + +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize( + "auth", + ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), + ), +) +@mark_async_test +async def test_credentials_are_not_logged(auth, fake_socket_pair, caplog): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x7.PACKER_CLS, + unpacker_cls=AsyncBolt5x7.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x7( + address, + sockets.client, + AsyncPoolConfig.max_connection_lifetime, + auth=auth, + ) + with caplog.at_level(logging.DEBUG): + await connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +def _assert_notifications_in_extra(extra, expected): + for key in expected: + assert key in extra + assert extra[key] == expected[key] + + +@pytest.mark.parametrize( + ("method", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ), +) +@pytest.mark.parametrize( + ("cls_min_sev", "method_min_sev"), + itertools.product((None, "WARNING", "OFF"), repeat=2), +) +@pytest.mark.parametrize( + ("cls_dis_clss", "method_dis_clss"), + itertools.product((None, [], ["HINT"], ["HINT", "DEPRECATION"]), repeat=2), +) +@mark_async_test +async def test_supports_notification_filters( + fake_socket, + method, + args, + extra_idx, + cls_min_sev, + method_min_sev, + cls_dis_clss, + method_dis_clss, +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x7.UNPACKER_CLS) + connection = AsyncBolt5x7( + address, + socket, + AsyncPoolConfig.max_connection_lifetime, + notifications_min_severity=cls_min_sev, + notifications_disabled_classifications=cls_dis_clss, + ) + method = getattr(connection, method) + + method( + *args, + notifications_min_severity=method_min_sev, + notifications_disabled_classifications=method_dis_clss, + ) + await connection.send_all() + + _, fields = await socket.pop_message() + extra = fields[extra_idx] + expected = {} + if method_min_sev is not None: + expected["notifications_minimum_severity"] = method_min_sev + if method_dis_clss is not None: + expected["notifications_disabled_classifications"] = method_dis_clss + _assert_notifications_in_extra(extra, expected) + + +@pytest.mark.parametrize("min_sev", (None, "WARNING", "OFF")) +@pytest.mark.parametrize( + "dis_clss", (None, [], ["HINT"], ["HINT", "DEPRECATION"]) +) +@mark_async_test +async def test_hello_supports_notification_filters( + fake_socket_pair, min_sev, dis_clss +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x7.PACKER_CLS, + unpacker_cls=AsyncBolt5x7.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x7( + address, + sockets.client, + AsyncPoolConfig.max_connection_lifetime, + notifications_min_severity=min_sev, + notifications_disabled_classifications=dis_clss, + ) + + await connection.hello() + + _tag, fields = await sockets.server.pop_message() + extra = fields[0] + expected = {} + if min_sev is not None: + expected["notifications_minimum_severity"] = min_sev + if dis_clss is not None: + expected["notifications_disabled_classifications"] = dis_clss + _assert_notifications_in_extra(extra, expected) + + +@mark_async_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +async def test_user_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x7.PACKER_CLS, + unpacker_cls=AsyncBolt5x7.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = AsyncBolt5x7( + address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + await connection.hello() + + _tag, fields = await sockets.server.pop_message() + extra = fields[0] + if not user_agent: + assert extra["user_agent"] == USER_AGENT + else: + assert extra["user_agent"] == user_agent + + +@mark_async_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +async def test_sends_bolt_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x7.PACKER_CLS, + unpacker_cls=AsyncBolt5x7.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = AsyncBolt5x7( + address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + await connection.hello() + + _tag, fields = await sockets.server.pop_message() + extra = fields[0] + assert extra["bolt_agent"] == BOLT_AGENT_DICT + + +@mark_async_test +@pytest.mark.parametrize( + ("func", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ), +) +@pytest.mark.parametrize( + ("timeout", "res"), + ( + (None, None), + (0, 0), + (0.1, 100), + (0.001, 1), + (1e-15, 1), + (0.0005, 1), + (0.0001, 1), + (1.0015, 1002), + (1.000499, 1000), + (1.0025, 1002), + (3.0005, 3000), + (3.456, 3456), + (1, 1000), + ( + -1e-15, + ValueError("Timeout must be a positive number or 0"), + ), + ( + "foo", + ValueError("Timeout must be specified as a number of seconds"), + ), + ( + [1, 2], + TypeError("Timeout must be specified as a number of seconds"), + ), + ), +) +async def test_tx_timeout( + fake_socket_pair, func, args, extra_idx, timeout, res +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x7.PACKER_CLS, + unpacker_cls=AsyncBolt5x7.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x7(address, sockets.client, 0) + func = getattr(connection, func) + if isinstance(res, Exception): + with pytest.raises(type(res), match=str(res)): + func(*args, timeout=timeout) + else: + func(*args, timeout=timeout) + await connection.send_all() + _tag, fields = await sockets.server.pop_message() + extra = fields[extra_idx] + if timeout is None: + assert "tx_timeout" not in extra + else: + assert extra["tx_timeout"] == res + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2, + ), +) +@mark_async_test +async def test_tracks_last_database(fake_socket_pair, actions): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x7.PACKER_CLS, + unpacker_cls=AsyncBolt5x7.UNPACKER_CLS, + ) + connection = AsyncBolt5x7(address, sockets.client, 0) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + await connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + await sockets.server.send_message(b"\x70", {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + await sockets.server.send_message(b"\x70", {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + await connection.send_all() + await connection.fetch_all() + assert connection.last_database == db + + await sockets.server.send_message(b"\x70", {}) + if finish == "reset": + await connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + await connection.send_all() + await connection.fetch_all() + + assert connection.last_database == db + + +DEFAULT_DIAG_REC_PAIRS = ( + ("OPERATION", ""), + ("OPERATION_CODE", "0"), + ("CURRENT_SCHEMA", "/"), +) + + +@pytest.mark.parametrize( + "sent_diag_records", + powerset( + ( + ..., + None, + {}, + [], + "1", + 1, + {"OPERATION_CODE": "0"}, + {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, + {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, + {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, + ), + upper_limit=3, + ), +) +@pytest.mark.parametrize("method", ("pull", "discard")) +@mark_async_test +async def test_enriches_statuses( + sent_diag_records, + method, + fake_socket_pair, +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x7.PACKER_CLS, + unpacker_cls=AsyncBolt5x7.UNPACKER_CLS, + ) + connection = AsyncBolt5x7(address, sockets.client, 0) + + sent_metadata = { + "statuses": [ + { + "status_description": "the status description", + "description": "description", + "diagnostic_record": r, + } + if r is not ... + else { + "status_description": "the status description", + "description": "description", + } + for r in sent_diag_records + ] + } + await sockets.server.send_message(b"\x70", sent_metadata) + + received_metadata = None + + def on_success(metadata): + nonlocal received_metadata + received_metadata = metadata + + getattr(connection, method)(on_success=on_success) + await connection.send_all() + await connection.fetch_all() + + def extend_diag_record(r): + if r is ...: + return dict(DEFAULT_DIAG_REC_PAIRS) + if isinstance(r, dict): + return dict((*DEFAULT_DIAG_REC_PAIRS, *r.items())) + return r + + expected_diag_records = [extend_diag_record(r) for r in sent_diag_records] + expected_metadata = { + "statuses": [ + { + "status_description": "the status description", + "description": "description", + "diagnostic_record": r, + } + if r is not ... + else { + "status_description": "the status description", + "description": "description", + } + for r in expected_diag_records + ] + } + + assert received_metadata == expected_metadata + + +@pytest.mark.parametrize( + "sent_diag_records", + powerset( + ( + ..., + None, + {}, + [], + "1", + 1, + {"OPERATION_CODE": "0"}, + {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, + {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, + {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, + ), + lower_limit=1, + upper_limit=3, + ), +) +@mark_async_test +async def test_enriches_error_statuses( + sent_diag_records, + fake_socket_pair, +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x7.PACKER_CLS, + unpacker_cls=AsyncBolt5x7.UNPACKER_CLS, + ) + connection = AsyncBolt5x7(address, sockets.client, 0) + sent_diag_records = [ + {**r, "_classification": "CLIENT_ERROR", "_status_parameters": {}} + if isinstance(r, dict) + else r + for r in sent_diag_records + ] + + sent_metadata = _build_error_hierarchy_metadata(sent_diag_records) + + await sockets.server.send_message(b"\x7f", sent_metadata) + + received_metadata = None + + def on_failure(metadata): + nonlocal received_metadata + received_metadata = metadata + + connection.run("RETURN 1", on_failure=on_failure) + await connection.send_all() + with pytest.raises(Neo4jError): + await connection.fetch_all() + + def extend_diag_record(r): + if r is ...: + return dict(DEFAULT_DIAG_REC_PAIRS) + if isinstance(r, dict): + return dict((*DEFAULT_DIAG_REC_PAIRS, *r.items())) + return r + + expected_diag_records = [extend_diag_record(r) for r in sent_diag_records] + expected_metadata = _build_error_hierarchy_metadata(expected_diag_records) + + assert received_metadata == expected_metadata + + +def _build_error_hierarchy_metadata(diag_records_metadata): + metadata = { + "gql_status": "FOO12", + "description": "but have you tried not doing that?!", + "message": "some people just can't be helped", + "neo4j_code": "Neo.ClientError.Generic.YouSuck", + } + if diag_records_metadata[0] is not ...: + metadata["diagnostic_record"] = diag_records_metadata[0] + current_root = metadata + for i, r in enumerate(diag_records_metadata[1:]): + current_root["cause"] = { + "description": f"error cause nr. {i + 1}", + "message": f"cause message {i + 1}", + } + current_root = current_root["cause"] + if r is not ...: + current_root["diagnostic_record"] = r + return metadata diff --git a/tests/unit/async_/io/test_neo4j_pool.py b/tests/unit/async_/io/test_neo4j_pool.py index 6fbd2de8..c0be16ad 100644 --- a/tests/unit/async_/io/test_neo4j_pool.py +++ b/tests/unit/async_/io/test_neo4j_pool.py @@ -535,11 +535,13 @@ async def test_passes_pool_config_to_connection(mocker): "error", ( ServiceUnavailable(), - Neo4jError.hydrate( - "message", "Neo.ClientError.Statement.EntityNotFound" + Neo4jError._hydrate_neo4j( + code="Neo.ClientError.Statement.EntityNotFound", + message="message", ), - Neo4jError.hydrate( - "message", "Neo.ClientError.Security.AuthorizationExpired" + Neo4jError._hydrate_neo4j( + code="Neo.ClientError.Security.AuthorizationExpired", + message="message", ), ), ) @@ -578,20 +580,20 @@ async def test_discovery_is_retried(custom_routing_opener, error): @pytest.mark.parametrize( "error", map( - lambda args: Neo4jError.hydrate(*args), + lambda args: Neo4jError._hydrate_neo4j(code=args[0], message=args[1]), ( - ("message", "Neo.ClientError.Database.DatabaseNotFound"), - ("message", "Neo.ClientError.Transaction.InvalidBookmark"), - ("message", "Neo.ClientError.Transaction.InvalidBookmarkMixture"), - ("message", "Neo.ClientError.Statement.TypeError"), - ("message", "Neo.ClientError.Statement.ArgumentError"), - ("message", "Neo.ClientError.Request.Invalid"), - ("message", "Neo.ClientError.Security.AuthenticationRateLimit"), - ("message", "Neo.ClientError.Security.CredentialsExpired"), - ("message", "Neo.ClientError.Security.Forbidden"), - ("message", "Neo.ClientError.Security.TokenExpired"), - ("message", "Neo.ClientError.Security.Unauthorized"), - ("message", "Neo.ClientError.Security.MadeUpError"), + ("Neo.ClientError.Database.DatabaseNotFound", "message"), + ("Neo.ClientError.Transaction.InvalidBookmark", "message"), + ("Neo.ClientError.Transaction.InvalidBookmarkMixture", "message"), + ("Neo.ClientError.Statement.TypeError", "message"), + ("Neo.ClientError.Statement.ArgumentError", "message"), + ("Neo.ClientError.Request.Invalid", "message"), + ("Neo.ClientError.Security.AuthenticationRateLimit", "message"), + ("Neo.ClientError.Security.CredentialsExpired", "message"), + ("Neo.ClientError.Security.Forbidden", "message"), + ("Neo.ClientError.Security.TokenExpired", "message"), + ("Neo.ClientError.Security.Unauthorized", "message"), + ("Neo.ClientError.Security.MadeUpError", "message"), ), ), ) @@ -627,7 +629,7 @@ async def test_fast_failing_discovery(custom_routing_opener, error): @pytest.mark.parametrize( ("error", "marks_unauthenticated", "fetches_new"), ( - (Neo4jError.hydrate("message", args[0]), *args[1:]) + (Neo4jError._hydrate_neo4j(code=args[0], message="message"), *args[1:]) for args in ( ("Neo.ClientError.Database.DatabaseNotFound", False, False), ("Neo.ClientError.Statement.TypeError", False, False), diff --git a/tests/unit/async_/test_auth_management.py b/tests/unit/async_/test_auth_management.py index dd731a9c..57420d3a 100644 --- a/tests/unit/async_/test_auth_management.py +++ b/tests/unit/async_/test_auth_management.py @@ -59,7 +59,7 @@ "Neo.ClientError.Security.Unauthorized", } SAMPLE_ERRORS = [ - Neo4jError.hydrate(code=code) + Neo4jError._hydrate_neo4j(code=code) for code in { "Neo.ClientError.Security.AuthenticationRateLimit", "Neo.ClientError.Security.AuthorizationExpired", diff --git a/tests/unit/async_/work/test_transaction.py b/tests/unit/async_/work/test_transaction.py index 4e30ff58..81fba42b 100644 --- a/tests/unit/async_/work/test_transaction.py +++ b/tests/unit/async_/work/test_transaction.py @@ -315,7 +315,12 @@ async def test_server_error_propagates(async_scripted_connection, error): ( "pull", { - "on_failure": ({"code": "Neo.ClientError.Made.Up"},), + "on_failure": ( + { + "neo4j_code": "Neo.ClientError.Made.Up", + "gql_status": "50N42", + }, + ), "on_summary": None, }, ) diff --git a/tests/unit/common/test_exceptions.py b/tests/unit/common/test_exceptions.py index 8cc192cd..dfb444a0 100644 --- a/tests/unit/common/test_exceptions.py +++ b/tests/unit/common/test_exceptions.py @@ -14,8 +14,16 @@ # limitations under the License. +from __future__ import annotations + +import contextlib +import re +import traceback + import pytest +import neo4j.exceptions +from neo4j import PreviewWarning from neo4j._exceptions import ( BoltError, BoltHandshakeError, @@ -28,6 +36,7 @@ CLASSIFICATION_TRANSIENT, ClientError, DatabaseError, + GqlError, Neo4jError, ServiceUnavailable, TransientError, @@ -148,8 +157,38 @@ def test_serviceunavailable_raised_from_bolt_protocol_error_with_explicit_style( assert e.value.__cause__ is error +def _assert_default_gql_error_attrs_from_neo4j_error(error: GqlError) -> None: + with pytest.warns(PreviewWarning, match="GQLSTATUS"): + assert error.gql_status == "50N42" + if error.message: + with pytest.warns(PreviewWarning, match="GQLSTATUS"): + assert error.gql_status_description == ( + "error: general processing exception - unexpected error. " + f"{error.message}" + ) + else: + with pytest.warns(PreviewWarning, match="GQLSTATUS"): + assert error.gql_status_description == ( + "error: general processing exception - unexpected error" + ) + with pytest.warns(PreviewWarning, match="GQLSTATUS"): + assert ( + error.gql_classification + == neo4j.exceptions.GqlErrorClassification.UNKNOWN + ) + with pytest.warns(PreviewWarning, match="GQLSTATUS"): + assert error.gql_raw_classification is None + with pytest.warns(PreviewWarning, match="GQLSTATUS"): + assert error.diagnostic_record == { + "CURRENT_SCHEMA": "/", + "OPERATION": "", + "OPERATION_CODE": "0", + } + assert error.__cause__ is None + + def test_neo4jerror_hydrate_with_no_args(): - error = Neo4jError.hydrate() + error = Neo4jError._hydrate_neo4j() assert isinstance(error, DatabaseError) assert error.classification == CLASSIFICATION_DATABASE @@ -158,10 +197,13 @@ def test_neo4jerror_hydrate_with_no_args(): assert error.metadata == {} assert error.message == "An unknown error occurred" assert error.code == "Neo.DatabaseError.General.UnknownError" + _assert_default_gql_error_attrs_from_neo4j_error(error) -def test_neo4jerror_hydrate_with_message_and_code_rubish(): - error = Neo4jError.hydrate(message="Test error message", code="ASDF_asdf") +def test_neo4jerror_hydrate_with_message_and_code_rubbish(): + error = Neo4jError._hydrate_neo4j( + message="Test error message", code="ASDF_asdf" + ) assert isinstance(error, DatabaseError) assert error.classification == CLASSIFICATION_DATABASE @@ -170,10 +212,11 @@ def test_neo4jerror_hydrate_with_message_and_code_rubish(): assert error.metadata == {} assert error.message == "Test error message" assert error.code == "ASDF_asdf" + _assert_default_gql_error_attrs_from_neo4j_error(error) def test_neo4jerror_hydrate_with_message_and_code_database(): - error = Neo4jError.hydrate( + error = Neo4jError._hydrate_neo4j( message="Test error message", code="Neo.DatabaseError.General.UnknownError", ) @@ -185,10 +228,11 @@ def test_neo4jerror_hydrate_with_message_and_code_database(): assert error.metadata == {} assert error.message == "Test error message" assert error.code == "Neo.DatabaseError.General.UnknownError" + _assert_default_gql_error_attrs_from_neo4j_error(error) def test_neo4jerror_hydrate_with_message_and_code_transient(): - error = Neo4jError.hydrate( + error = Neo4jError._hydrate_neo4j( message="Test error message", code="Neo.TransientError.General.TestError", ) @@ -200,10 +244,11 @@ def test_neo4jerror_hydrate_with_message_and_code_transient(): assert error.metadata == {} assert error.message == "Test error message" assert error.code == f"Neo.{CLASSIFICATION_TRANSIENT}.General.TestError" + _assert_default_gql_error_attrs_from_neo4j_error(error) def test_neo4jerror_hydrate_with_message_and_code_client(): - error = Neo4jError.hydrate( + error = Neo4jError._hydrate_neo4j( message="Test error message", code=f"Neo.{CLASSIFICATION_CLIENT}.General.TestError", ) @@ -215,6 +260,7 @@ def test_neo4jerror_hydrate_with_message_and_code_client(): assert error.metadata == {} assert error.message == "Test error message" assert error.code == f"Neo.{CLASSIFICATION_CLIENT}.General.TestError" + _assert_default_gql_error_attrs_from_neo4j_error(error) @pytest.mark.parametrize( @@ -252,9 +298,20 @@ def test_neo4jerror_hydrate_with_message_and_code_client(): ), ), ) -def test_error_rewrite(code, expected_cls, expected_code): +@pytest.mark.parametrize("mode", ("neo4j", "gql")) +def test_error_rewrite(code, expected_cls, expected_code, mode): message = "Test error message" - error = Neo4jError.hydrate(message=message, code=code) + if mode == "neo4j": + error = Neo4jError._hydrate_neo4j(message=message, code=code) + elif mode == "gql": + error = Neo4jError._hydrate_gql( + gql_status="12345", + description="error: things - they hit the fan", + message=message, + neo4j_code=code, + ) + else: + raise ValueError(f"Invalid mode {mode!r}") expected_retryable = expected_cls is TransientError assert error.__class__ is expected_cls @@ -266,49 +323,95 @@ def test_error_rewrite(code, expected_cls, expected_code): @pytest.mark.parametrize( - ("code", "message", "expected_cls", "expected_str"), + ("code", "message", "expected_cls", "expected_str", "mode"), ( - ( - "Neo.ClientError.General.UnknownError", - "Test error message", - ClientError, - "{code: Neo.ClientError.General.UnknownError} " - "{message: Test error message}", - ), - ( - None, - "Test error message", - DatabaseError, - "{code: Neo.DatabaseError.General.UnknownError} " - "{message: Test error message}", + # values that behave the same in both modes + *( + ( + *x, + mode, + ) + for mode in ("neo4j", "gql") + for x in ( + ( + "Neo.ClientError.General.UnknownError", + "Test error message", + ClientError, + ( + "{code: Neo.ClientError.General.UnknownError} " + "{message: Test error message}" + ), + ), + ( + None, + "Test error message", + DatabaseError, + ( + "{code: Neo.DatabaseError.General.UnknownError} " + "{message: Test error message}" + ), + ), + ( + "Neo.ClientError.General.UnknownError", + None, + ClientError, + ( + "{code: Neo.ClientError.General.UnknownError} " + "{message: An unknown error occurred}" + ), + ), + ) ), + # neo4j error specific behavior ( "", "Test error message", DatabaseError, - "{code: Neo.DatabaseError.General.UnknownError} " - "{message: Test error message}", + ( + "{code: Neo.DatabaseError.General.UnknownError} " + "{message: Test error message}" + ), + "neo4j", ), ( "Neo.ClientError.General.UnknownError", - None, + "", ClientError, "{code: Neo.ClientError.General.UnknownError} " "{message: An unknown error occurred}", + "neo4j", + ), + # gql error specific behavior + ( + "", + "Test error message", + DatabaseError, + "{code: } {message: Test error message}", + "gql", ), ( "Neo.ClientError.General.UnknownError", "", ClientError, - "{code: Neo.ClientError.General.UnknownError} " - "{message: An unknown error occurred}", + "{code: Neo.ClientError.General.UnknownError} {message: }", + "gql", ), ), ) def test_neo4j_error_from_server_as_str( - code, message, expected_cls, expected_str + code, message, expected_cls, expected_str, mode ): - error = Neo4jError.hydrate(code=code, message=message) + if mode == "neo4j": + error = Neo4jError._hydrate_neo4j(code=code, message=message) + elif mode == "gql": + error = Neo4jError._hydrate_gql( + gql_status="12345", + description="error: things - they hit the fan", + neo4j_code=code, + message=message, + ) + else: + raise ValueError(f"Invalid mode {mode!r}") assert type(error) is expected_cls assert str(error) == expected_str @@ -320,3 +423,328 @@ def test_neo4j_error_from_code_as_str(cls): assert type(error) is cls assert str(error) == "Generated somewhere in the driver" + + +def _make_test_gql_error( + identifier: str, + cause: GqlError | None = None, +) -> GqlError: + error = GqlError(identifier) + error._init_gql( + gql_status=f"{identifier[:5].upper():<05}", + description=f"error: $h!t went down - {identifier}", + message=identifier, + cause=cause, + ) + return error + + +def _set_error_cause(exc, cause, method="set") -> None: + if method == "set": + exc.__cause__ = cause + elif method == "raise": + with contextlib.suppress(exc.__class__): + raise exc from cause + else: + raise ValueError(f"Invalid cause set method {method!r}") + + +_CYCLIC_CAUSE_MARKER = object() + + +def _assert_error_chain( + exc: BaseException, + expected: list[object], +) -> None: + assert isinstance(exc, BaseException) + + collection_root: BaseException | None = exc + actual_chain: list[object] = [exc] + actual_chain_ids = [id(exc)] + while collection_root is not None: + cause = getattr(collection_root, "__cause__", None) + if id(cause) in actual_chain_ids: + actual_chain.append(_CYCLIC_CAUSE_MARKER) + actual_chain_ids.append(id(_CYCLIC_CAUSE_MARKER)) + break + actual_chain.append(cause) + actual_chain_ids.append(id(cause)) + collection_root = cause + + assert actual_chain_ids == list(map(id, expected)) + + expected_lines = [ + str(exc) + for exc in expected + if exc is not None and exc is not _CYCLIC_CAUSE_MARKER + ] + expected_lines.reverse() + exc_fmt = traceback.format_exception(type(exc), exc, exc.__traceback__) + for line in exc_fmt: + if not expected_lines: + break + if expected_lines[0] in line: + expected_lines.pop(0) + if expected_lines: + traceback_fmt = "".join(exc_fmt) + pytest.fail( + f"Expected lines not found: {expected_lines} in traceback:\n" + f"{traceback_fmt}" + ) + + +def test_cause_chain_extension_no_cause() -> None: + root = _make_test_gql_error("root") + + _assert_error_chain(root, [root, None]) + + +def test_cause_chain_extension_only_gql_cause() -> None: + root_cause = _make_test_gql_error("rootCause") + root = _make_test_gql_error("root", cause=root_cause) + + _assert_error_chain(root, [root, root_cause, None]) + + +@pytest.mark.parametrize("local_cause_method", ("raise", "set")) +def test_cause_chain_extension_only_local_cause(local_cause_method) -> None: + root_cause = ClientError("rootCause") + root = _make_test_gql_error("root") + _set_error_cause(root, root_cause, local_cause_method) + + _assert_error_chain(root, [root, root_cause, None]) + + +@pytest.mark.parametrize("local_cause_method", ("raise", "set")) +def test_cause_chain_extension_multiple_causes(local_cause_method) -> None: + root4_cause2 = _make_test_gql_error("r4c2") + root4_cause1 = _make_test_gql_error("r4c1", cause=root4_cause2) + root4 = _make_test_gql_error("root4", cause=root4_cause1) + root3 = ClientError("root3") + _set_error_cause(root3, root4, local_cause_method) + root2_cause3 = _make_test_gql_error("r2c3") + root2_cause2 = _make_test_gql_error("r2c2", cause=root2_cause3) + root2_cause1 = _make_test_gql_error("r2c1", cause=root2_cause2) + root2 = _make_test_gql_error("root2", cause=root2_cause1) + _set_error_cause(root2, root3, local_cause_method) + root1_cause2 = _make_test_gql_error("r1c2") + root1_cause1 = _make_test_gql_error("r1c1", cause=root1_cause2) + root1 = _make_test_gql_error("root1", cause=root1_cause1) + _set_error_cause(root1, root2, local_cause_method) + + _assert_error_chain( + root1, [ + root1, root1_cause1, root1_cause2, + root2, root2_cause1, root2_cause2, root2_cause3, + root3, + root4, root4_cause1, root4_cause2, + None, + ], + ) # fmt: skip + + +@pytest.mark.parametrize("local_cause_method", ("raise", "set")) +def test_cause_chain_extension_circular_local_causes( + local_cause_method, +) -> None: + root6 = ClientError("root6") + root5 = _make_test_gql_error("root5") + _set_error_cause(root5, root6, local_cause_method) + root4_cause = _make_test_gql_error("r4c") + root4 = _make_test_gql_error("root4", cause=root4_cause) + _set_error_cause(root4, root5, local_cause_method) + root3 = ClientError("root3") + _set_error_cause(root3, root4, local_cause_method) + root2 = _make_test_gql_error("root2") + _set_error_cause(root2, root3, local_cause_method) + root1 = ClientError("root1") + _set_error_cause(root1, root2, local_cause_method) + _set_error_cause(root6, root1, local_cause_method) + + _assert_error_chain( + root1, + [ + root1, + root2, + root3, + root4, + root4_cause, + root5, + root6, + _CYCLIC_CAUSE_MARKER, + ], + ) + + +_DEFAULT_GQL_ERROR_ATTRIBUTES = { + "code": "Neo.DatabaseError.General.UnknownError", + "classification": "DatabaseError", + "category": "General", + "title": "UnknownError", + "message": "An unknown error occurred", + "gql_status": "50N42", + "gql_status_description": ( + "error: general processing exception - unexpected error" + ), + "gql_classification": neo4j.exceptions.GqlErrorClassification.UNKNOWN, + "gql_raw_classification": None, + "diagnostic_record": { + "CURRENT_SCHEMA": "/", + "OPERATION": "", + "OPERATION_CODE": "0", + }, + "__cause__": None, +} + + +@pytest.mark.parametrize( + ("metadata", "attributes"), + ( + # all default values + ( + {}, + _DEFAULT_GQL_ERROR_ATTRIBUTES, + ), + # example from ADR + ( + { + "gql_status": "01N00", + "message": "01EXAMPLE you have failed something", + "description": "client error - example error. Message", + "neo4j_code": "Neo.Example.Failure.Code", + "diagnostic_record": { + "CURRENT_SCHEMA": "", + "OPERATION": "", + "OPERATION_CODE": "", + "_classification": "CLIENT_ERROR", + "_status_parameters": {}, + }, + }, + { + "code": "Neo.Example.Failure.Code", + "classification": "Example", + "category": "Failure", + "title": "Code", + "message": "01EXAMPLE you have failed something", + "gql_status": "01N00", + "gql_status_description": ( + "client error - example error. Message" + ), + "gql_classification": ( + neo4j.exceptions.GqlErrorClassification.CLIENT_ERROR + ), + "gql_raw_classification": "CLIENT_ERROR", + "diagnostic_record": { + "CURRENT_SCHEMA": "", + "OPERATION": "", + "OPERATION_CODE": "", + "_classification": "CLIENT_ERROR", + "_status_parameters": {}, + }, + "__cause__": None, + }, + ), + # garbage diagnostic record + ( + { + "diagnostic_record": { + "CURRENT_SCHEMA": 1.5, + "OPERATION": False, + "_classification": ["whelp", None], + "_🤡": "🎈", + "foo": {"bar": "baz"}, + }, + }, + { + **_DEFAULT_GQL_ERROR_ATTRIBUTES, + "gql_classification": ( + neo4j.exceptions.GqlErrorClassification.UNKNOWN + ), + "gql_raw_classification": None, + "diagnostic_record": { + "CURRENT_SCHEMA": 1.5, + "OPERATION": False, + "_classification": ["whelp", None], + "_🤡": "🎈", + "foo": {"bar": "baz"}, + }, + }, + ), + ( + { + "diagnostic_record": { + "_classification": "SOME_FUTURE_CLASSIFICATION", + }, + }, + { + **_DEFAULT_GQL_ERROR_ATTRIBUTES, + "gql_classification": ( + neo4j.exceptions.GqlErrorClassification.UNKNOWN + ), + "gql_raw_classification": "SOME_FUTURE_CLASSIFICATION", + "diagnostic_record": { + "_classification": "SOME_FUTURE_CLASSIFICATION", + }, + }, + ), + ), +) +def test_gql_hydration(metadata, attributes): + # TODO: test causes + error = Neo4jError._hydrate_gql(**metadata) + + preview_attrs = { + "gql_status", + "gql_status_description", + "gql_classification", + "gql_raw_classification", + "diagnostic_record", + } + + for attr in ( + "code", + "classification", + "category", + "title", + "message", + "gql_status", + "gql_status_description", + "gql_classification", + "gql_raw_classification", + "diagnostic_record", + "__cause__", + ): + expected_value = attributes[attr] + if attr in preview_attrs: + with pytest.warns(PreviewWarning, match="GQLSTATUS"): + actual_value = getattr(error, attr) + else: + actual_value = getattr(error, attr) + assert actual_value == expected_value + + +@pytest.mark.parametrize( + "attr", + ( + "code", + "classification", + "category", + "title", + "message", + "metadata", + ), +) +def test_deprecated_setter(attr): + obj = object() + error = Neo4jError() + + with pytest.warns( + DeprecationWarning, + match=re.compile( + rf".*\baltering\b.*\b{attr}\b.*", + flags=re.IGNORECASE, + ), + ): + setattr(error, attr, obj) + + assert getattr(error, attr) is obj diff --git a/tests/unit/common/test_record.py b/tests/unit/common/test_record.py index 0c515654..b3c38301 100644 --- a/tests/unit/common/test_record.py +++ b/tests/unit/common/test_record.py @@ -591,5 +591,6 @@ class TestError(Exception): with pytest.raises(BrokenRecordError) as raised: accessor(r) exc_value = raised.value + assert exc_value.__cause__ is not None assert exc_value.__cause__ is exc assert list(traceback.walk_tb(exc_value.__cause__.__traceback__)) == frames diff --git a/tests/unit/common/work/test_summary.py b/tests/unit/common/work/test_summary.py index 58adc9c9..74f2059a 100644 --- a/tests/unit/common/work/test_summary.py +++ b/tests/unit/common/work/test_summary.py @@ -889,6 +889,7 @@ def test_summary_result_counters(summary_args_kwargs, counters_set) -> None: ((5, 4), "t_first"), ((5, 5), "t_first"), ((5, 6), "t_first"), + ((5, 7), "t_first"), ), ) def test_summary_result_available_after( @@ -925,6 +926,7 @@ def test_summary_result_available_after( ((5, 4), "t_last"), ((5, 5), "t_last"), ((5, 6), "t_last"), + ((5, 7), "t_last"), ), ) def test_summary_result_consumed_after( diff --git a/tests/unit/sync/fixtures/fake_connection.py b/tests/unit/sync/fixtures/fake_connection.py index 08f63df6..8785badb 100644 --- a/tests/unit/sync/fixtures/fake_connection.py +++ b/tests/unit/sync/fixtures/fake_connection.py @@ -206,7 +206,7 @@ def callback(): cb_args = default_cb_args res = cb(*cb_args) if cb_name == "on_failure": - error = Neo4jError.hydrate(**cb_args[0]) + error = Neo4jError._hydrate_gql(**cb_args[0]) # suppress in case the callback is not async with suppress(TypeError): res diff --git a/tests/unit/sync/io/test_class_bolt.py b/tests/unit/sync/io/test_class_bolt.py index f82d441d..f3b06303 100644 --- a/tests/unit/sync/io/test_class_bolt.py +++ b/tests/unit/sync/io/test_class_bolt.py @@ -38,7 +38,7 @@ def test_class_method_protocol_handlers(): expected_handlers = { (3, 0), (4, 1), (4, 2), (4, 3), (4, 4), - (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), + (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 7), } # fmt: on @@ -68,7 +68,8 @@ def test_class_method_protocol_handlers(): ((5, 4), 1), ((5, 5), 1), ((5, 6), 1), - ((5, 7), 0), + ((5, 7), 1), + ((5, 8), 0), ((6, 0), 0), ], ) @@ -91,7 +92,7 @@ def test_class_method_get_handshake(): handshake = Bolt.get_handshake() assert ( handshake - == b"\x00\x06\x06\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" + == b"\x00\x07\x07\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" ) @@ -141,6 +142,7 @@ def test_cancel_hello_in_open(mocker, none_auth): ((5, 4), "neo4j._sync.io._bolt5.Bolt5x4"), ((5, 5), "neo4j._sync.io._bolt5.Bolt5x5"), ((5, 6), "neo4j._sync.io._bolt5.Bolt5x6"), + ((5, 7), "neo4j._sync.io._bolt5.Bolt5x7"), ), ) @mark_sync_test @@ -179,7 +181,7 @@ def test_version_negotiation( (2, 0), (4, 0), (3, 1), - (5, 7), + (5, 8), (6, 0), ), ) @@ -187,7 +189,7 @@ def test_version_negotiation( def test_failing_version_negotiation(mocker, bolt_version, none_auth): supported_protocols = ( "('3.0', '4.1', '4.2', '4.3', '4.4', " - "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5', '5.6')" + "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5', '5.6', '5.7')" ) address = ("localhost", 7687) diff --git a/tests/unit/sync/io/test_class_bolt3.py b/tests/unit/sync/io/test_class_bolt3.py index d6b1b3c1..ba80ce81 100644 --- a/tests/unit/sync/io/test_class_bolt3.py +++ b/tests/unit/sync/io/test_class_bolt3.py @@ -532,7 +532,7 @@ def raises_if_db(db): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/sync/io/test_class_bolt4x0.py b/tests/unit/sync/io/test_class_bolt4x0.py index 619902f6..a0ad36e8 100644 --- a/tests/unit/sync/io/test_class_bolt4x0.py +++ b/tests/unit/sync/io/test_class_bolt4x0.py @@ -623,7 +623,7 @@ def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/sync/io/test_class_bolt4x1.py b/tests/unit/sync/io/test_class_bolt4x1.py index d89bcd5d..c4b0208a 100644 --- a/tests/unit/sync/io/test_class_bolt4x1.py +++ b/tests/unit/sync/io/test_class_bolt4x1.py @@ -645,7 +645,7 @@ def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/sync/io/test_class_bolt4x2.py b/tests/unit/sync/io/test_class_bolt4x2.py index bdcee8c6..b6ac961a 100644 --- a/tests/unit/sync/io/test_class_bolt4x2.py +++ b/tests/unit/sync/io/test_class_bolt4x2.py @@ -645,7 +645,7 @@ def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/sync/io/test_class_bolt4x3.py b/tests/unit/sync/io/test_class_bolt4x3.py index 333eccec..c5da8700 100644 --- a/tests/unit/sync/io/test_class_bolt4x3.py +++ b/tests/unit/sync/io/test_class_bolt4x3.py @@ -674,7 +674,7 @@ def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/sync/io/test_class_bolt4x4.py b/tests/unit/sync/io/test_class_bolt4x4.py index 5eb4cbb7..164372b0 100644 --- a/tests/unit/sync/io/test_class_bolt4x4.py +++ b/tests/unit/sync/io/test_class_bolt4x4.py @@ -634,7 +634,7 @@ def test_tx_timeout( {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/sync/io/test_class_bolt5x0.py b/tests/unit/sync/io/test_class_bolt5x0.py index 4b3677c8..6f26b97a 100644 --- a/tests/unit/sync/io/test_class_bolt5x0.py +++ b/tests/unit/sync/io/test_class_bolt5x0.py @@ -698,7 +698,7 @@ def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/sync/io/test_class_bolt5x1.py b/tests/unit/sync/io/test_class_bolt5x1.py index cd2219fd..dfe638a9 100644 --- a/tests/unit/sync/io/test_class_bolt5x1.py +++ b/tests/unit/sync/io/test_class_bolt5x1.py @@ -752,7 +752,7 @@ def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/sync/io/test_class_bolt5x2.py b/tests/unit/sync/io/test_class_bolt5x2.py index 017df421..5dc09be8 100644 --- a/tests/unit/sync/io/test_class_bolt5x2.py +++ b/tests/unit/sync/io/test_class_bolt5x2.py @@ -789,7 +789,7 @@ def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/sync/io/test_class_bolt5x3.py b/tests/unit/sync/io/test_class_bolt5x3.py index 533006cb..af852710 100644 --- a/tests/unit/sync/io/test_class_bolt5x3.py +++ b/tests/unit/sync/io/test_class_bolt5x3.py @@ -676,7 +676,7 @@ def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/sync/io/test_class_bolt5x4.py b/tests/unit/sync/io/test_class_bolt5x4.py index 34585091..5773d1f6 100644 --- a/tests/unit/sync/io/test_class_bolt5x4.py +++ b/tests/unit/sync/io/test_class_bolt5x4.py @@ -681,7 +681,7 @@ def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/sync/io/test_class_bolt5x5.py b/tests/unit/sync/io/test_class_bolt5x5.py index 90d5fd21..361a9c14 100644 --- a/tests/unit/sync/io/test_class_bolt5x5.py +++ b/tests/unit/sync/io/test_class_bolt5x5.py @@ -690,7 +690,7 @@ def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/sync/io/test_class_bolt5x6.py b/tests/unit/sync/io/test_class_bolt5x6.py index fcde8a7c..15f37872 100644 --- a/tests/unit/sync/io/test_class_bolt5x6.py +++ b/tests/unit/sync/io/test_class_bolt5x6.py @@ -690,7 +690,7 @@ def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/sync/io/test_class_bolt5x7.py b/tests/unit/sync/io/test_class_bolt5x7.py new file mode 100644 index 00000000..cf999cc6 --- /dev/null +++ b/tests/unit/sync/io/test_class_bolt5x7.py @@ -0,0 +1,850 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import itertools +import logging + +import pytest + +import neo4j +from neo4j._api import TelemetryAPI +from neo4j._meta import ( + BOLT_AGENT_DICT, + USER_AGENT, +) +from neo4j._sync.config import PoolConfig +from neo4j._sync.io._bolt5 import Bolt5x7 +from neo4j.exceptions import Neo4jError + +from ...._async_compat import mark_sync_test +from ....iter_util import powerset + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 0 + connection = Bolt5x7( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is True + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = -1 + connection = Bolt5x7( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 999999999 + connection = Bolt5x7( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize( + ("args", "kwargs", "expected_fields"), + ( + (("", {}), {"db": "something"}, ({"db": "something"},)), + (("", {}), {"imp_user": "imposter"}, ({"imp_user": "imposter"},)), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ({"db": "something", "imp_user": "imposter"},), + ), + ), +) +@mark_sync_test +def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x7.UNPACKER_CLS) + connection = Bolt5x7( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.begin(*args, **kwargs) + connection.send_all() + tag, is_fields = socket.pop_message() + assert tag == b"\x11" + assert tuple(is_fields) == expected_fields + + +@pytest.mark.parametrize( + ("args", "kwargs", "expected_fields"), + ( + (("", {}), {"db": "something"}, ("", {}, {"db": "something"})), + ( + ("", {}), + {"imp_user": "imposter"}, + ("", {}, {"imp_user": "imposter"}), + ), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ("", {}, {"db": "something", "imp_user": "imposter"}), + ), + ), +) +@mark_sync_test +def test_extra_in_run(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x7.UNPACKER_CLS) + connection = Bolt5x7( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.run(*args, **kwargs) + connection.send_all() + tag, is_fields = socket.pop_message() + assert tag == b"\x10" + assert tuple(is_fields) == expected_fields + + +@mark_sync_test +def test_n_extra_in_discard(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x7.UNPACKER_CLS) + connection = Bolt5x7( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.discard(n=666) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == {"n": 666} + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (666, {"n": -1, "qid": 666}), + (-1, {"n": -1}), + ], +) +@mark_sync_test +def test_qid_extra_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x7.UNPACKER_CLS) + connection = Bolt5x7( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.discard(qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (777, {"n": 666, "qid": 777}), + (-1, {"n": 666}), + ], +) +@mark_sync_test +def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x7.UNPACKER_CLS) + connection = Bolt5x7( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.discard(n=666, qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (666, {"n": 666}), + (-1, {"n": -1}), + ], +) +@mark_sync_test +def test_n_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x7.UNPACKER_CLS) + connection = Bolt5x7( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.pull(n=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (777, {"n": -1, "qid": 777}), + (-1, {"n": -1}), + ], +) +@mark_sync_test +def test_qid_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x7.UNPACKER_CLS) + connection = Bolt5x7( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.pull(qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == expected + + +@mark_sync_test +def test_n_and_qid_extras_in_pull(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x7.UNPACKER_CLS) + connection = Bolt5x7( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.pull(n=666, qid=777) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == {"n": 666, "qid": 777} + + +@mark_sync_test +def test_hello_passes_routing_metadata(fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x7.PACKER_CLS, + unpacker_cls=Bolt5x7.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.4.0"}) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x7( + address, + sockets.client, + PoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"}, + ) + connection.hello() + tag, fields = sockets.server.pop_message() + assert tag == b"\x01" + assert len(fields) == 1 + assert fields[0]["routing"] == {"foo": "bar"} + + +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_sync_test +def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x7.UNPACKER_CLS) + connection = Bolt5x7( + address, + socket, + PoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled, + ) + if serv_enabled: + connection.configuration_hints["telemetry.enabled"] = True + connection.telemetry(api) + connection.send_all() + + if serv_enabled and not driver_disabled: + tag, fields = socket.pop_message() + assert tag == b"\x54" + assert fields == [int(api)] + else: + with pytest.raises(OSError): + socket.pop_message() + + +@pytest.mark.parametrize( + ("hints", "valid"), + ( + ({"connection.recv_timeout_seconds": 1}, True), + ({"connection.recv_timeout_seconds": 42}, True), + ({}, True), + ({"whatever_this_is": "ignore me!"}, True), + ({"connection.recv_timeout_seconds": -1}, False), + ({"connection.recv_timeout_seconds": 0}, False), + ({"connection.recv_timeout_seconds": 2.5}, False), + ({"connection.recv_timeout_seconds": None}, False), + ({"connection.recv_timeout_seconds": False}, False), + ({"connection.recv_timeout_seconds": "1"}, False), + ), +) +@mark_sync_test +def test_hint_recv_timeout_seconds( + fake_socket_pair, hints, valid, caplog, mocker +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x7.PACKER_CLS, + unpacker_cls=Bolt5x7.UNPACKER_CLS, + ) + sockets.client.settimeout = mocker.Mock() + sockets.server.send_message( + b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} + ) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x7( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + with caplog.at_level(logging.INFO): + connection.hello() + if valid: + if "connection.recv_timeout_seconds" in hints: + sockets.client.settimeout.assert_called_once_with( + hints["connection.recv_timeout_seconds"] + ) + else: + sockets.client.settimeout.assert_not_called() + assert not any( + "recv_timeout_seconds" in msg and "invalid" in msg + for msg in caplog.messages + ) + else: + sockets.client.settimeout.assert_not_called() + assert any( + repr(hints["connection.recv_timeout_seconds"]) in msg + and "recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages + ) + + +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize( + "auth", + ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), + ), +) +@mark_sync_test +def test_credentials_are_not_logged(auth, fake_socket_pair, caplog): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x7.PACKER_CLS, + unpacker_cls=Bolt5x7.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x7( + address, + sockets.client, + PoolConfig.max_connection_lifetime, + auth=auth, + ) + with caplog.at_level(logging.DEBUG): + connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +def _assert_notifications_in_extra(extra, expected): + for key in expected: + assert key in extra + assert extra[key] == expected[key] + + +@pytest.mark.parametrize( + ("method", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ), +) +@pytest.mark.parametrize( + ("cls_min_sev", "method_min_sev"), + itertools.product((None, "WARNING", "OFF"), repeat=2), +) +@pytest.mark.parametrize( + ("cls_dis_clss", "method_dis_clss"), + itertools.product((None, [], ["HINT"], ["HINT", "DEPRECATION"]), repeat=2), +) +@mark_sync_test +def test_supports_notification_filters( + fake_socket, + method, + args, + extra_idx, + cls_min_sev, + method_min_sev, + cls_dis_clss, + method_dis_clss, +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x7.UNPACKER_CLS) + connection = Bolt5x7( + address, + socket, + PoolConfig.max_connection_lifetime, + notifications_min_severity=cls_min_sev, + notifications_disabled_classifications=cls_dis_clss, + ) + method = getattr(connection, method) + + method( + *args, + notifications_min_severity=method_min_sev, + notifications_disabled_classifications=method_dis_clss, + ) + connection.send_all() + + _, fields = socket.pop_message() + extra = fields[extra_idx] + expected = {} + if method_min_sev is not None: + expected["notifications_minimum_severity"] = method_min_sev + if method_dis_clss is not None: + expected["notifications_disabled_classifications"] = method_dis_clss + _assert_notifications_in_extra(extra, expected) + + +@pytest.mark.parametrize("min_sev", (None, "WARNING", "OFF")) +@pytest.mark.parametrize( + "dis_clss", (None, [], ["HINT"], ["HINT", "DEPRECATION"]) +) +@mark_sync_test +def test_hello_supports_notification_filters( + fake_socket_pair, min_sev, dis_clss +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x7.PACKER_CLS, + unpacker_cls=Bolt5x7.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x7( + address, + sockets.client, + PoolConfig.max_connection_lifetime, + notifications_min_severity=min_sev, + notifications_disabled_classifications=dis_clss, + ) + + connection.hello() + + _tag, fields = sockets.server.pop_message() + extra = fields[0] + expected = {} + if min_sev is not None: + expected["notifications_minimum_severity"] = min_sev + if dis_clss is not None: + expected["notifications_disabled_classifications"] = dis_clss + _assert_notifications_in_extra(extra, expected) + + +@mark_sync_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +def test_user_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x7.PACKER_CLS, + unpacker_cls=Bolt5x7.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = Bolt5x7( + address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + connection.hello() + + _tag, fields = sockets.server.pop_message() + extra = fields[0] + if not user_agent: + assert extra["user_agent"] == USER_AGENT + else: + assert extra["user_agent"] == user_agent + + +@mark_sync_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +def test_sends_bolt_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x7.PACKER_CLS, + unpacker_cls=Bolt5x7.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = Bolt5x7( + address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + connection.hello() + + _tag, fields = sockets.server.pop_message() + extra = fields[0] + assert extra["bolt_agent"] == BOLT_AGENT_DICT + + +@mark_sync_test +@pytest.mark.parametrize( + ("func", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ), +) +@pytest.mark.parametrize( + ("timeout", "res"), + ( + (None, None), + (0, 0), + (0.1, 100), + (0.001, 1), + (1e-15, 1), + (0.0005, 1), + (0.0001, 1), + (1.0015, 1002), + (1.000499, 1000), + (1.0025, 1002), + (3.0005, 3000), + (3.456, 3456), + (1, 1000), + ( + -1e-15, + ValueError("Timeout must be a positive number or 0"), + ), + ( + "foo", + ValueError("Timeout must be specified as a number of seconds"), + ), + ( + [1, 2], + TypeError("Timeout must be specified as a number of seconds"), + ), + ), +) +def test_tx_timeout( + fake_socket_pair, func, args, extra_idx, timeout, res +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x7.PACKER_CLS, + unpacker_cls=Bolt5x7.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x7(address, sockets.client, 0) + func = getattr(connection, func) + if isinstance(res, Exception): + with pytest.raises(type(res), match=str(res)): + func(*args, timeout=timeout) + else: + func(*args, timeout=timeout) + connection.send_all() + _tag, fields = sockets.server.pop_message() + extra = fields[extra_idx] + if timeout is None: + assert "tx_timeout" not in extra + else: + assert extra["tx_timeout"] == res + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2, + ), +) +@mark_sync_test +def test_tracks_last_database(fake_socket_pair, actions): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x7.PACKER_CLS, + unpacker_cls=Bolt5x7.UNPACKER_CLS, + ) + connection = Bolt5x7(address, sockets.client, 0) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + sockets.server.send_message(b"\x70", {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + sockets.server.send_message(b"\x70", {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + connection.send_all() + connection.fetch_all() + assert connection.last_database == db + + sockets.server.send_message(b"\x70", {}) + if finish == "reset": + connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + connection.send_all() + connection.fetch_all() + + assert connection.last_database == db + + +DEFAULT_DIAG_REC_PAIRS = ( + ("OPERATION", ""), + ("OPERATION_CODE", "0"), + ("CURRENT_SCHEMA", "/"), +) + + +@pytest.mark.parametrize( + "sent_diag_records", + powerset( + ( + ..., + None, + {}, + [], + "1", + 1, + {"OPERATION_CODE": "0"}, + {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, + {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, + {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, + ), + upper_limit=3, + ), +) +@pytest.mark.parametrize("method", ("pull", "discard")) +@mark_sync_test +def test_enriches_statuses( + sent_diag_records, + method, + fake_socket_pair, +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x7.PACKER_CLS, + unpacker_cls=Bolt5x7.UNPACKER_CLS, + ) + connection = Bolt5x7(address, sockets.client, 0) + + sent_metadata = { + "statuses": [ + { + "status_description": "the status description", + "description": "description", + "diagnostic_record": r, + } + if r is not ... + else { + "status_description": "the status description", + "description": "description", + } + for r in sent_diag_records + ] + } + sockets.server.send_message(b"\x70", sent_metadata) + + received_metadata = None + + def on_success(metadata): + nonlocal received_metadata + received_metadata = metadata + + getattr(connection, method)(on_success=on_success) + connection.send_all() + connection.fetch_all() + + def extend_diag_record(r): + if r is ...: + return dict(DEFAULT_DIAG_REC_PAIRS) + if isinstance(r, dict): + return dict((*DEFAULT_DIAG_REC_PAIRS, *r.items())) + return r + + expected_diag_records = [extend_diag_record(r) for r in sent_diag_records] + expected_metadata = { + "statuses": [ + { + "status_description": "the status description", + "description": "description", + "diagnostic_record": r, + } + if r is not ... + else { + "status_description": "the status description", + "description": "description", + } + for r in expected_diag_records + ] + } + + assert received_metadata == expected_metadata + + +@pytest.mark.parametrize( + "sent_diag_records", + powerset( + ( + ..., + None, + {}, + [], + "1", + 1, + {"OPERATION_CODE": "0"}, + {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, + {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, + {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, + ), + lower_limit=1, + upper_limit=3, + ), +) +@mark_sync_test +def test_enriches_error_statuses( + sent_diag_records, + fake_socket_pair, +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x7.PACKER_CLS, + unpacker_cls=Bolt5x7.UNPACKER_CLS, + ) + connection = Bolt5x7(address, sockets.client, 0) + sent_diag_records = [ + {**r, "_classification": "CLIENT_ERROR", "_status_parameters": {}} + if isinstance(r, dict) + else r + for r in sent_diag_records + ] + + sent_metadata = _build_error_hierarchy_metadata(sent_diag_records) + + sockets.server.send_message(b"\x7f", sent_metadata) + + received_metadata = None + + def on_failure(metadata): + nonlocal received_metadata + received_metadata = metadata + + connection.run("RETURN 1", on_failure=on_failure) + connection.send_all() + with pytest.raises(Neo4jError): + connection.fetch_all() + + def extend_diag_record(r): + if r is ...: + return dict(DEFAULT_DIAG_REC_PAIRS) + if isinstance(r, dict): + return dict((*DEFAULT_DIAG_REC_PAIRS, *r.items())) + return r + + expected_diag_records = [extend_diag_record(r) for r in sent_diag_records] + expected_metadata = _build_error_hierarchy_metadata(expected_diag_records) + + assert received_metadata == expected_metadata + + +def _build_error_hierarchy_metadata(diag_records_metadata): + metadata = { + "gql_status": "FOO12", + "description": "but have you tried not doing that?!", + "message": "some people just can't be helped", + "neo4j_code": "Neo.ClientError.Generic.YouSuck", + } + if diag_records_metadata[0] is not ...: + metadata["diagnostic_record"] = diag_records_metadata[0] + current_root = metadata + for i, r in enumerate(diag_records_metadata[1:]): + current_root["cause"] = { + "description": f"error cause nr. {i + 1}", + "message": f"cause message {i + 1}", + } + current_root = current_root["cause"] + if r is not ...: + current_root["diagnostic_record"] = r + return metadata diff --git a/tests/unit/sync/io/test_neo4j_pool.py b/tests/unit/sync/io/test_neo4j_pool.py index edc777d0..89b4d16b 100644 --- a/tests/unit/sync/io/test_neo4j_pool.py +++ b/tests/unit/sync/io/test_neo4j_pool.py @@ -535,11 +535,13 @@ def test_passes_pool_config_to_connection(mocker): "error", ( ServiceUnavailable(), - Neo4jError.hydrate( - "message", "Neo.ClientError.Statement.EntityNotFound" + Neo4jError._hydrate_neo4j( + code="Neo.ClientError.Statement.EntityNotFound", + message="message", ), - Neo4jError.hydrate( - "message", "Neo.ClientError.Security.AuthorizationExpired" + Neo4jError._hydrate_neo4j( + code="Neo.ClientError.Security.AuthorizationExpired", + message="message", ), ), ) @@ -578,20 +580,20 @@ def test_discovery_is_retried(custom_routing_opener, error): @pytest.mark.parametrize( "error", map( - lambda args: Neo4jError.hydrate(*args), + lambda args: Neo4jError._hydrate_neo4j(code=args[0], message=args[1]), ( - ("message", "Neo.ClientError.Database.DatabaseNotFound"), - ("message", "Neo.ClientError.Transaction.InvalidBookmark"), - ("message", "Neo.ClientError.Transaction.InvalidBookmarkMixture"), - ("message", "Neo.ClientError.Statement.TypeError"), - ("message", "Neo.ClientError.Statement.ArgumentError"), - ("message", "Neo.ClientError.Request.Invalid"), - ("message", "Neo.ClientError.Security.AuthenticationRateLimit"), - ("message", "Neo.ClientError.Security.CredentialsExpired"), - ("message", "Neo.ClientError.Security.Forbidden"), - ("message", "Neo.ClientError.Security.TokenExpired"), - ("message", "Neo.ClientError.Security.Unauthorized"), - ("message", "Neo.ClientError.Security.MadeUpError"), + ("Neo.ClientError.Database.DatabaseNotFound", "message"), + ("Neo.ClientError.Transaction.InvalidBookmark", "message"), + ("Neo.ClientError.Transaction.InvalidBookmarkMixture", "message"), + ("Neo.ClientError.Statement.TypeError", "message"), + ("Neo.ClientError.Statement.ArgumentError", "message"), + ("Neo.ClientError.Request.Invalid", "message"), + ("Neo.ClientError.Security.AuthenticationRateLimit", "message"), + ("Neo.ClientError.Security.CredentialsExpired", "message"), + ("Neo.ClientError.Security.Forbidden", "message"), + ("Neo.ClientError.Security.TokenExpired", "message"), + ("Neo.ClientError.Security.Unauthorized", "message"), + ("Neo.ClientError.Security.MadeUpError", "message"), ), ), ) @@ -627,7 +629,7 @@ def test_fast_failing_discovery(custom_routing_opener, error): @pytest.mark.parametrize( ("error", "marks_unauthenticated", "fetches_new"), ( - (Neo4jError.hydrate("message", args[0]), *args[1:]) + (Neo4jError._hydrate_neo4j(code=args[0], message="message"), *args[1:]) for args in ( ("Neo.ClientError.Database.DatabaseNotFound", False, False), ("Neo.ClientError.Statement.TypeError", False, False), diff --git a/tests/unit/sync/test_auth_management.py b/tests/unit/sync/test_auth_management.py index da90a940..34eb4d72 100644 --- a/tests/unit/sync/test_auth_management.py +++ b/tests/unit/sync/test_auth_management.py @@ -59,7 +59,7 @@ "Neo.ClientError.Security.Unauthorized", } SAMPLE_ERRORS = [ - Neo4jError.hydrate(code=code) + Neo4jError._hydrate_neo4j(code=code) for code in { "Neo.ClientError.Security.AuthenticationRateLimit", "Neo.ClientError.Security.AuthorizationExpired", diff --git a/tests/unit/sync/work/test_transaction.py b/tests/unit/sync/work/test_transaction.py index 97f31409..53aeba1e 100644 --- a/tests/unit/sync/work/test_transaction.py +++ b/tests/unit/sync/work/test_transaction.py @@ -315,7 +315,12 @@ def test_server_error_propagates(scripted_connection, error): ( "pull", { - "on_failure": ({"code": "Neo.ClientError.Made.Up"},), + "on_failure": ( + { + "neo4j_code": "Neo.ClientError.Made.Up", + "gql_status": "50N42", + }, + ), "on_summary": None, }, )