Skip to content

Raise hydration errors on value access #759

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1323,6 +1323,8 @@ Client-side errors

* :class:`neo4j.exceptions.ResultNotSingleError`

* :class:`neo4j.exceptions.BrokenRecordError`

* :class:`neo4j.exceptions.SessionExpired`

* :class:`neo4j.exceptions.ServiceUnavailable`
Expand Down Expand Up @@ -1360,6 +1362,9 @@ Client-side errors
.. autoclass:: neo4j.exceptions.ResultNotSingleError
:show-inheritance:

.. autoclass:: neo4j.exceptions.BrokenRecordError
:show-inheritance:

.. autoclass:: neo4j.exceptions.SessionExpired
:show-inheritance:

Expand Down
4 changes: 0 additions & 4 deletions neo4j/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@
"BoltDriver",
"Bookmark",
"Bookmarks",
"Config",
"custom_auth",
"DEFAULT_DATABASE",
"Driver",
Expand All @@ -120,15 +119,13 @@
"kerberos_auth",
"ManagedTransaction",
"Neo4jDriver",
"PoolConfig",
"Query",
"READ_ACCESS",
"Record",
"Result",
"ResultSummary",
"ServerInfo",
"Session",
"SessionConfig",
"SummaryCounters",
"Transaction",
"TRUST_ALL_CERTIFICATES",
Expand All @@ -138,7 +135,6 @@
"TrustSystemCAs",
"unit_of_work",
"Version",
"WorkspaceConfig",
"WRITE_ACCESS",
]

Expand Down
6 changes: 6 additions & 0 deletions neo4j/_async/work/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from warnings import warn

from ..._async_compat.util import AsyncUtil
from ..._codec.hydration import BrokenHydrationObject
from ..._data import (
Record,
RecordTableRowExporter,
Expand Down Expand Up @@ -145,6 +146,11 @@ async def on_failed_attach(metadata):
def _pull(self):
def on_records(records):
if not self._discarding:
records = (
record.raw_data
if isinstance(record, BrokenHydrationObject) else record
for record in records
)
self._record_buffer.extend((
Record(zip(self._keys, record))
for record in records
Expand Down
6 changes: 5 additions & 1 deletion neo4j/_codec/hydration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from ._common import HydrationScope
from ._common import (
BrokenHydrationObject,
HydrationScope,
)
from ._interface import HydrationHandlerABC


__all__ = [
"BrokenHydrationObject",
"HydrationHandlerABC",
"HydrationScope",
]
57 changes: 53 additions & 4 deletions neo4j/_codec/hydration/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,46 @@
# limitations under the License.


from copy import copy

from ...graph import Graph
from ..packstream import Structure


class BrokenHydrationObject:
"""
Represents an object from the server, not understood by the driver.

A :class:`neo4j.Record` might contain a ``BrokenHydrationObject`` object
if the driver received data from the server that it did not understand.
This can for instance happen if the server sends a zoned datetime with a
zone name unknown to the driver.

There is no need to explicitly check for this type. Any method on the
:class:`neo4j.Record` that would return a ``BrokenHydrationObject``, will
raise a :exc:`neo4j.exceptions.BrokenRecordError`
with the original exception as cause.
"""

def __init__(self, error, raw_data):
self.error = error
"The exception raised while decoding the received object."
self.raw_data = raw_data
"""The raw data that caused the exception."""

def exception_copy(self):
exc_copy = copy(self.error)
exc_copy.with_traceback(self.error.__traceback__)
return exc_copy


class GraphHydrator:
def __init__(self):
self.graph = Graph()
self.struct_hydration_functions = {}


class HydrationScope:

def __init__(self, hydration_handler, graph_hydrator):
self._hydration_handler = hydration_handler
self._graph_hydrator = graph_hydrator
Expand All @@ -37,14 +65,35 @@ def __init__(self, hydration_handler, graph_hydrator):
}
self.hydration_hooks = {
Structure: self._hydrate_structure,
list: self._hydrate_list,
dict: self._hydrate_dict,
}
self.dehydration_hooks = hydration_handler.dehydration_functions

def _hydrate_structure(self, value):
f = self._struct_hydration_functions.get(value.tag)
if not f:
return value
return f(*value.fields)
try:
if not f:
raise ValueError(
f"Protocol error: unknown Structure tag: {value.tag!r}"
)
return f(*value.fields)
except Exception as e:
return BrokenHydrationObject(e, value)

@staticmethod
def _hydrate_list(value):
for v in value:
if isinstance(v, BrokenHydrationObject):
return BrokenHydrationObject(v.error, value)
return value

@staticmethod
def _hydrate_dict(value):
for v in value.values():
if isinstance(v, BrokenHydrationObject):
return BrokenHydrationObject(v.error, value)
return value

def get_graph(self):
return self._graph_hydrator.graph
41 changes: 35 additions & 6 deletions neo4j/_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@
from functools import reduce
from operator import xor as xor_operator

from ._codec.hydration import BrokenHydrationObject
from ._conf import iter_items
from ._meta import deprecated
from .exceptions import BrokenRecordError
from .graph import (
Node,
Path,
Expand All @@ -55,9 +58,26 @@ def __new__(cls, iterable=()):
inst.__keys = tuple(keys)
return inst

def _broken_record_error(self, index):
return BrokenRecordError(
f"Record contains broken data at {index} ('{self.__keys[index]}')"
)

def _super_getitem_single(self, index):
value = super().__getitem__(index)
if isinstance(value, BrokenHydrationObject):
raise self._broken_record_error(index) from value.error
return value

def __repr__(self):
return "<%s %s>" % (self.__class__.__name__,
" ".join("%s=%r" % (field, self[i]) for i, field in enumerate(self.__keys)))
return "<%s %s>" % (
self.__class__.__name__,
" ".join("%s=%r" % (field, value)
for field, value in zip(self.__keys, super().__iter__()))
)

def __str__(self):
return self.__repr__()

def __eq__(self, other):
""" In order to be flexible regarding comparison, the equality rules
Expand All @@ -83,18 +103,26 @@ def __ne__(self, other):
def __hash__(self):
return reduce(xor_operator, map(hash, self.items()))

def __iter__(self):
for i, v in enumerate(super().__iter__()):
if isinstance(v, BrokenHydrationObject):
raise self._broken_record_error(i) from v.error
yield v

def __getitem__(self, key):
if isinstance(key, slice):
keys = self.__keys[key]
values = super(Record, self).__getitem__(key)
values = super().__getitem__(key)
return self.__class__(zip(keys, values))
try:
index = self.index(key)
except IndexError:
return None
else:
return super(Record, self).__getitem__(index)
return self._super_getitem_single(index)

# TODO: 6.0 - remove
@deprecated("This method is deprecated and will be removed in the future.")
def __getslice__(self, start, stop):
key = slice(start, stop)
keys = self.__keys[key]
Expand All @@ -114,7 +142,7 @@ def get(self, key, default=None):
except ValueError:
return default
if 0 <= index < len(self):
return super(Record, self).__getitem__(index)
return self._super_getitem_single(index)
else:
return default

Expand Down Expand Up @@ -197,7 +225,8 @@ def items(self, *keys):
else:
d.append((self.__keys[i], self[i]))
return d
return list((self.__keys[i], super(Record, self).__getitem__(i)) for i in range(len(self)))
return list((self.__keys[i], self._super_getitem_single(i))
for i in range(len(self)))

def data(self, *keys):
""" Return the keys and values of this record as a dictionary,
Expand Down
6 changes: 6 additions & 0 deletions neo4j/_sync/work/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from warnings import warn

from ..._async_compat.util import Util
from ..._codec.hydration import BrokenHydrationObject
from ..._data import (
Record,
RecordTableRowExporter,
Expand Down Expand Up @@ -145,6 +146,11 @@ def on_failed_attach(metadata):
def _pull(self):
def on_records(records):
if not self._discarding:
records = (
record.raw_data
if isinstance(record, BrokenHydrationObject) else record
for record in records
)
self._record_buffer.extend((
Record(zip(self._keys, record))
for record in records
Expand Down
10 changes: 10 additions & 0 deletions neo4j/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
+ ResultError
+ ResultConsumedError
+ ResultNotSingleError
+ BrokenRecordError
+ SessionExpired
+ ServiceUnavailable
+ RoutingServiceUnavailable
Expand Down Expand Up @@ -395,6 +396,15 @@ class ResultNotSingleError(ResultError):
"""Raised when a result should have exactly one record but does not."""


# DriverError > BrokenRecordError
class BrokenRecordError(DriverError):
""" Raised when accessing a Record's field that couldn't be decoded.

This can for instance happen when the server sends a zoned datetime with a
zone id unknown to the client.
"""


# DriverError > SessionExpired
class SessionExpired(DriverError):
""" Raised when a session is no longer able to fulfil
Expand Down
23 changes: 17 additions & 6 deletions testkitbackend/_async/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,21 @@ def _exc_stems_from_driver(exc):
if DRIVER_PATH in p.parents:
return True

@staticmethod
def _exc_msg(exc, max_depth=10):
if isinstance(exc, Neo4jError) and exc.message is not None:
return str(exc.message)

depth = 0
res = str(exc)
while getattr(exc, "__cause__", None) is not None:
depth += 1
if depth >= max_depth:
break
res += f"\nCaused by: {exc.__cause__!r}"
exc = exc.__cause__
return res

async def write_driver_exc(self, exc):
log.debug(traceback.format_exc())

Expand All @@ -109,14 +124,10 @@ async def write_driver_exc(self, exc):
wrapped_exc = exc.wrapped_exc
payload["errorType"] = str(type(wrapped_exc))
if wrapped_exc.args:
payload["msg"] = str(wrapped_exc.args[0])
payload["msg"] = self._exc_msg(wrapped_exc.args[0])
else:
payload["errorType"] = str(type(exc))
if isinstance(exc, Neo4jError) and exc.message is not None:
payload["msg"] = str(exc.message)
elif exc.args:
payload["msg"] = str(exc.args[0])

payload["msg"] = self._exc_msg(exc)
if isinstance(exc, Neo4jError):
payload["code"] = exc.code

Expand Down
23 changes: 17 additions & 6 deletions testkitbackend/_sync/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,21 @@ def _exc_stems_from_driver(exc):
if DRIVER_PATH in p.parents:
return True

@staticmethod
def _exc_msg(exc, max_depth=10):
if isinstance(exc, Neo4jError) and exc.message is not None:
return str(exc.message)

depth = 0
res = str(exc)
while getattr(exc, "__cause__", None) is not None:
depth += 1
if depth >= max_depth:
break
res += f"\nCaused by: {exc.__cause__!r}"
exc = exc.__cause__
return res

def write_driver_exc(self, exc):
log.debug(traceback.format_exc())

Expand All @@ -109,14 +124,10 @@ def write_driver_exc(self, exc):
wrapped_exc = exc.wrapped_exc
payload["errorType"] = str(type(wrapped_exc))
if wrapped_exc.args:
payload["msg"] = str(wrapped_exc.args[0])
payload["msg"] = self._exc_msg(wrapped_exc.args[0])
else:
payload["errorType"] = str(type(exc))
if isinstance(exc, Neo4jError) and exc.message is not None:
payload["msg"] = str(exc.message)
elif exc.args:
payload["msg"] = str(exc.args[0])

payload["msg"] = self._exc_msg(exc)
if isinstance(exc, Neo4jError):
payload["code"] = exc.code

Expand Down
Loading