From e63d2a27dc1a4ab0bcc6554e34760d056d77586b Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Thu, 5 Oct 2023 14:48:15 +0200 Subject: [PATCH 1/3] Propagate errors across all results in a transaction A new error `ResultFailedError` is introduced. It will be raised when using a `Result` object after the result or another result in the same transaction has failed. User code would only ever run into this situation when catch exceptions and deciding to ignore them. Now, an error will be raised instead of undefined behavior. The undefined behavior before this fix could be (among other things) protocol violations, incomplete summary data, and hard to interpret errors. --- docs/source/api.rst | 5 +++++ src/neo4j/_async/work/result.py | 25 ++++++++++++++++++++++++- src/neo4j/_async/work/transaction.py | 2 ++ src/neo4j/_sync/work/result.py | 25 ++++++++++++++++++++++++- src/neo4j/_sync/work/transaction.py | 2 ++ src/neo4j/exceptions.py | 12 ++++++++++++ testkitbackend/test_config.json | 8 +------- 7 files changed, 70 insertions(+), 9 deletions(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index 34ba14e78..ca23b4fb2 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -1902,6 +1902,8 @@ Client-side errors * :class:`neo4j.exceptions.ResultError` + * :class:`neo4j.exceptions.ResultFailedError` + * :class:`neo4j.exceptions.ResultConsumedError` * :class:`neo4j.exceptions.ResultNotSingleError` @@ -1946,6 +1948,9 @@ Client-side errors :show-inheritance: :members: result +.. autoexception:: neo4j.exceptions.ResultFailedError() + :show-inheritance: + .. autoexception:: neo4j.exceptions.ResultConsumedError() :show-inheritance: diff --git a/src/neo4j/_async/work/result.py b/src/neo4j/_async/work/result.py index fe87d6cd4..4824771ef 100644 --- a/src/neo4j/_async/work/result.py +++ b/src/neo4j/_async/work/result.py @@ -38,6 +38,7 @@ ) from ...exceptions import ( ResultConsumedError, + ResultFailedError, ResultNotSingleError, ) from ...time import ( @@ -57,6 +58,10 @@ _TResultKey = t.Union[int, str] +_RESULT_FAILED_ERROR = ( + "The result has failed. Either this result or another result in the same" + "transaction has encountered an error." +) _RESULT_OUT_OF_SCOPE_ERROR = ( "The result is out of scope. The associated transaction " "has been closed. Results can only be used while the " @@ -76,8 +81,11 @@ class AsyncResult: """ def __init__(self, connection, fetch_size, on_closed, on_error): - self._connection = ConnectionErrorHandler(connection, on_error) + self._connection = ConnectionErrorHandler( + connection, self._connection_error_handler + ) self._hydration_scope = connection.new_hydration_scope() + self._on_error = on_error self._on_closed = on_closed self._metadata = None self._keys = None @@ -101,6 +109,13 @@ def __init__(self, connection, fetch_size, on_closed, on_error): self._consumed = False # the result has been closed as a result of closing the transaction self._out_of_scope = False + # exception shared across all results of a transaction + self._exception = None + + async def _connection_error_handler(self, exc): + self._exception = exc + self._attached = False + await AsyncUtil.callback(self._on_error, exc) @property def _qid(self): @@ -257,6 +272,9 @@ async def __aiter__(self) -> t.AsyncIterator[Record]: await self._connection.send_all() self._exhausted = True + if self._exception is not None: + raise ResultFailedError(self, _RESULT_FAILED_ERROR) \ + from self._exception if self._out_of_scope: raise ResultConsumedError(self, _RESULT_OUT_OF_SCOPE_ERROR) if self._consumed: @@ -346,6 +364,11 @@ async def _tx_end(self): await self._exhaust() self._out_of_scope = True + def _tx_failure(self, exc): + # Handle failure of the associated transaction. + self._attached = False + self._exception = exc + async def consume(self) -> ResultSummary: """Consume the remainder of this result and return a :class:`neo4j.ResultSummary`. diff --git a/src/neo4j/_async/work/transaction.py b/src/neo4j/_async/work/transaction.py index 42eeac7f1..f009bf7ca 100644 --- a/src/neo4j/_async/work/transaction.py +++ b/src/neo4j/_async/work/transaction.py @@ -92,6 +92,8 @@ async def _result_on_closed_handler(self): async def _error_handler(self, exc): self._last_error = exc + for result in self._results: + result._tx_failure(exc) if isinstance(exc, asyncio.CancelledError): self._cancel() return diff --git a/src/neo4j/_sync/work/result.py b/src/neo4j/_sync/work/result.py index be3ea2e35..0153d4bee 100644 --- a/src/neo4j/_sync/work/result.py +++ b/src/neo4j/_sync/work/result.py @@ -38,6 +38,7 @@ ) from ...exceptions import ( ResultConsumedError, + ResultFailedError, ResultNotSingleError, ) from ...time import ( @@ -57,6 +58,10 @@ _TResultKey = t.Union[int, str] +_RESULT_FAILED_ERROR = ( + "The result has failed. Either this result or another result in the same" + "transaction has encountered an error." +) _RESULT_OUT_OF_SCOPE_ERROR = ( "The result is out of scope. The associated transaction " "has been closed. Results can only be used while the " @@ -76,8 +81,11 @@ class Result: """ def __init__(self, connection, fetch_size, on_closed, on_error): - self._connection = ConnectionErrorHandler(connection, on_error) + self._connection = ConnectionErrorHandler( + connection, self._connection_error_handler + ) self._hydration_scope = connection.new_hydration_scope() + self._on_error = on_error self._on_closed = on_closed self._metadata = None self._keys = None @@ -101,6 +109,13 @@ def __init__(self, connection, fetch_size, on_closed, on_error): self._consumed = False # the result has been closed as a result of closing the transaction self._out_of_scope = False + # exception shared across all results of a transaction + self._exception = None + + def _connection_error_handler(self, exc): + self._exception = exc + self._attached = False + Util.callback(self._on_error, exc) @property def _qid(self): @@ -257,6 +272,9 @@ def __iter__(self) -> t.Iterator[Record]: self._connection.send_all() self._exhausted = True + if self._exception is not None: + raise ResultFailedError(self, _RESULT_FAILED_ERROR) \ + from self._exception if self._out_of_scope: raise ResultConsumedError(self, _RESULT_OUT_OF_SCOPE_ERROR) if self._consumed: @@ -346,6 +364,11 @@ def _tx_end(self): self._exhaust() self._out_of_scope = True + def _tx_failure(self, exc): + # Handle failure of the associated transaction. + self._attached = False + self._exception = exc + def consume(self) -> ResultSummary: """Consume the remainder of this result and return a :class:`neo4j.ResultSummary`. diff --git a/src/neo4j/_sync/work/transaction.py b/src/neo4j/_sync/work/transaction.py index 9e1cc8c0b..1eda4faad 100644 --- a/src/neo4j/_sync/work/transaction.py +++ b/src/neo4j/_sync/work/transaction.py @@ -92,6 +92,8 @@ def _result_on_closed_handler(self): def _error_handler(self, exc): self._last_error = exc + for result in self._results: + result._tx_failure(exc) if isinstance(exc, asyncio.CancelledError): self._cancel() return diff --git a/src/neo4j/exceptions.py b/src/neo4j/exceptions.py index 1c366fae8..ac1d3ae93 100644 --- a/src/neo4j/exceptions.py +++ b/src/neo4j/exceptions.py @@ -40,6 +40,7 @@ + TransactionError + TransactionNestingError + ResultError + + ResultFailedError + ResultConsumedError + ResultNotSingleError + BrokenRecordError @@ -464,6 +465,17 @@ def __init__(self, result_, *args, **kwargs): self.result = result_ +# DriverError > ResultError > ResultFailedError +class ResultFailedError(ResultError): + """Raised when trying to access records of a failed result. + + A :class:`.Result` will be considered failed if + * itself encountered an error while fetching records + * another result within the same transaction encountered an error while + fetching records + """ + + # DriverError > ResultError > ResultConsumedError class ResultConsumedError(ResultError): """Raised when trying to access records of a consumed result.""" diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index 5957054e9..92af6d4ab 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -13,13 +13,7 @@ "'neo4j.datatypes.test_temporal_types.TestDataTypes.test_should_echo_all_timezone_ids'": "test_subtest_skips.dt_conversion", "'neo4j.datatypes.test_temporal_types.TestDataTypes.test_date_time_cypher_created_tz_id'": - "test_subtest_skips.tz_id", - "'stub.tx_run.test_tx_run.TestTxRun.test_should_prevent_discard_after_tx_termination_on_run'": - "Fixme: transactions don't prevent further actions after failure.", - "'stub.tx_run.test_tx_run.TestTxRun.test_should_prevent_pull_after_tx_termination_on_pull'": - "Fixme: transactions don't prevent further actions after failure.", - "'stub.tx_run.test_tx_run.TestTxRun.test_should_prevent_pull_after_tx_termination_on_run'": - "Fixme: transactions don't prevent further actions after failure." + "test_subtest_skips.tz_id" }, "features": { "Feature:API:BookmarkManager": true, From b03702dac7944d324a4b079f33ab0cbccfc33c6e Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Fri, 6 Oct 2023 11:44:13 +0200 Subject: [PATCH 2/3] Fix missing space in error message --- src/neo4j/_async/work/result.py | 2 +- src/neo4j/_sync/work/result.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/neo4j/_async/work/result.py b/src/neo4j/_async/work/result.py index 4824771ef..9e1b5c786 100644 --- a/src/neo4j/_async/work/result.py +++ b/src/neo4j/_async/work/result.py @@ -59,7 +59,7 @@ _RESULT_FAILED_ERROR = ( - "The result has failed. Either this result or another result in the same" + "The result has failed. Either this result or another result in the same " "transaction has encountered an error." ) _RESULT_OUT_OF_SCOPE_ERROR = ( diff --git a/src/neo4j/_sync/work/result.py b/src/neo4j/_sync/work/result.py index 0153d4bee..b2a079962 100644 --- a/src/neo4j/_sync/work/result.py +++ b/src/neo4j/_sync/work/result.py @@ -59,7 +59,7 @@ _RESULT_FAILED_ERROR = ( - "The result has failed. Either this result or another result in the same" + "The result has failed. Either this result or another result in the same " "transaction has encountered an error." ) _RESULT_OUT_OF_SCOPE_ERROR = ( From dbd634a783e869dfc0e84bee2a93b1b2ee543584 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Mon, 9 Oct 2023 11:08:51 +0200 Subject: [PATCH 3/3] Add unit tests for transaction error propagation --- testkitbackend/_async/requests.py | 2 +- testkitbackend/_sync/requests.py | 2 +- tests/unit/async_/fixtures/fake_connection.py | 12 ++++- tests/unit/async_/work/test_transaction.py | 54 ++++++++++++++++++- tests/unit/sync/fixtures/fake_connection.py | 12 ++++- tests/unit/sync/work/test_transaction.py | 54 ++++++++++++++++++- 6 files changed, 130 insertions(+), 6 deletions(-) diff --git a/testkitbackend/_async/requests.py b/testkitbackend/_async/requests.py index 1ad3cbdb2..2038c8baa 100644 --- a/testkitbackend/_async/requests.py +++ b/testkitbackend/_async/requests.py @@ -401,7 +401,7 @@ async def ExecuteQuery(backend, data): def resolution_func(backend, custom_resolver=False, custom_dns_resolver=False): # This solution (putting custom resolution together with DNS resolution - # into one function only works because the Python driver calls the custom + # into one function) only works because the Python driver calls the custom # resolver function for every connection, which is not true for all # drivers. Properly exposing a way to change the DNS lookup behavior is not # possible without changing the driver's code. diff --git a/testkitbackend/_sync/requests.py b/testkitbackend/_sync/requests.py index 9932021e3..efaec61bb 100644 --- a/testkitbackend/_sync/requests.py +++ b/testkitbackend/_sync/requests.py @@ -401,7 +401,7 @@ def ExecuteQuery(backend, data): def resolution_func(backend, custom_resolver=False, custom_dns_resolver=False): # This solution (putting custom resolution together with DNS resolution - # into one function only works because the Python driver calls the custom + # into one function) only works because the Python driver calls the custom # resolver function for every connection, which is not true for all # drivers. Properly exposing a way to change the DNS lookup behavior is not # possible without changing the driver's code. diff --git a/tests/unit/async_/fixtures/fake_connection.py b/tests/unit/async_/fixtures/fake_connection.py index 9e3995cfc..5989b1521 100644 --- a/tests/unit/async_/fixtures/fake_connection.py +++ b/tests/unit/async_/fixtures/fake_connection.py @@ -24,6 +24,7 @@ from neo4j._async.io import AsyncBolt from neo4j._deadline import Deadline from neo4j.auth_management import AsyncAuthManager +from neo4j.exceptions import Neo4jError __all__ = [ @@ -154,10 +155,12 @@ def set_script(self, callbacks): [ ("run", {"on_success": ({},), "on_summary": None}), ("pull", { + "on_records": ([some_record],), "on_success": None, "on_summary": None, - "on_records": }) + # use any exception to throw it instead of calling handlers + ("commit", RuntimeError("oh no!")) ] ``` Note that arguments can be `None`. In this case, ScriptedConnection @@ -180,6 +183,9 @@ def func(*args, **kwargs): self._script_pos += 1 async def callback(): + if isinstance(scripted_callbacks, BaseException): + raise scripted_callbacks + error = None for cb_name, default_cb_args in ( ("on_ignored", ({},)), ("on_failure", ({},)), @@ -197,10 +203,14 @@ async def callback(): if cb_args is None: cb_args = default_cb_args res = cb(*cb_args) + if cb_name == "on_failure": + error = Neo4jError.hydrate(**cb_args[0]) try: await res # maybe the callback is async except TypeError: pass # or maybe it wasn't ;) + if error is not None: + raise error self.callbacks.append(callback) diff --git a/tests/unit/async_/work/test_transaction.py b/tests/unit/async_/work/test_transaction.py index 33238c76f..b7d4d7194 100644 --- a/tests/unit/async_/work/test_transaction.py +++ b/tests/unit/async_/work/test_transaction.py @@ -17,7 +17,6 @@ from unittest.mock import MagicMock -from uuid import uuid4 import pytest @@ -26,6 +25,11 @@ NotificationMinimumSeverity, Query, ) +from neo4j.exceptions import ( + ClientError, + ResultFailedError, + ServiceUnavailable, +) from ...._async_compat import mark_async_test @@ -275,3 +279,51 @@ async def test_transaction_begin_pipelining( expected_calls.append(("send_all",)) expected_calls.append(("fetch_all",)) assert async_fake_connection.method_calls == expected_calls + + +@pytest.mark.parametrize("error", ("server", "connection")) +@mark_async_test +async def test_server_error_propagates(async_scripted_connection, error): + connection = async_scripted_connection + script = [ + # res 1 + ("run", {"on_success": ({"fields": ["n"]},), "on_summary": None}), + ("pull", {"on_records": ([[1], [2]],), + "on_success": ({"has_more": True},)}), + # res 2 + ("run", {"on_success": ({"fields": ["n"]},), "on_summary": None}), + ("pull", {"on_records": ([[1], [2]],), + "on_success": ({"has_more": True},)}), + ] + if error == "server": + script.append( + ("pull", {"on_failure": ({"code": "Neo.ClientError.Made.Up"},), + "on_summary": None}) + ) + expected_error = ClientError + elif error == "connection": + script.append(("pull", ServiceUnavailable())) + expected_error = ServiceUnavailable + else: + raise ValueError(f"Unknown error type {error}") + connection.set_script(script) + + tx = AsyncTransaction( + connection, 2, lambda *args, **kwargs: None, + lambda *args, **kwargs: None, lambda *args, **kwargs: None + ) + res1 = await tx.run("UNWIND range(1, 1000) AS n RETURN n") + assert await res1.__anext__() == {"n": 1} + + res2 = await tx.run("RETURN 'causes error later'") + assert await res2.fetch(2) == [{"n": 1}, {"n": 2}] + with pytest.raises(expected_error) as exc1: + await res2.__anext__() + + # can finish the buffer + assert await res1.fetch(1) == [{"n": 2}] + # then fails because the connection was broken by res2 + with pytest.raises(ResultFailedError) as exc2: + await res1.__anext__() + + assert exc1.value is exc2.value.__cause__ diff --git a/tests/unit/sync/fixtures/fake_connection.py b/tests/unit/sync/fixtures/fake_connection.py index 659daebe9..e504ef50d 100644 --- a/tests/unit/sync/fixtures/fake_connection.py +++ b/tests/unit/sync/fixtures/fake_connection.py @@ -24,6 +24,7 @@ from neo4j._deadline import Deadline from neo4j._sync.io import Bolt from neo4j.auth_management import AuthManager +from neo4j.exceptions import Neo4jError __all__ = [ @@ -154,10 +155,12 @@ def set_script(self, callbacks): [ ("run", {"on_success": ({},), "on_summary": None}), ("pull", { + "on_records": ([some_record],), "on_success": None, "on_summary": None, - "on_records": }) + # use any exception to throw it instead of calling handlers + ("commit", RuntimeError("oh no!")) ] ``` Note that arguments can be `None`. In this case, ScriptedConnection @@ -180,6 +183,9 @@ def func(*args, **kwargs): self._script_pos += 1 def callback(): + if isinstance(scripted_callbacks, BaseException): + raise scripted_callbacks + error = None for cb_name, default_cb_args in ( ("on_ignored", ({},)), ("on_failure", ({},)), @@ -197,10 +203,14 @@ def callback(): if cb_args is None: cb_args = default_cb_args res = cb(*cb_args) + if cb_name == "on_failure": + error = Neo4jError.hydrate(**cb_args[0]) try: res # maybe the callback is async except TypeError: pass # or maybe it wasn't ;) + if error is not None: + raise error self.callbacks.append(callback) diff --git a/tests/unit/sync/work/test_transaction.py b/tests/unit/sync/work/test_transaction.py index 6bffe7846..d13681b02 100644 --- a/tests/unit/sync/work/test_transaction.py +++ b/tests/unit/sync/work/test_transaction.py @@ -17,7 +17,6 @@ from unittest.mock import MagicMock -from uuid import uuid4 import pytest @@ -26,6 +25,11 @@ Query, Transaction, ) +from neo4j.exceptions import ( + ClientError, + ResultFailedError, + ServiceUnavailable, +) from ...._async_compat import mark_sync_test @@ -275,3 +279,51 @@ def test_transaction_begin_pipelining( expected_calls.append(("send_all",)) expected_calls.append(("fetch_all",)) assert fake_connection.method_calls == expected_calls + + +@pytest.mark.parametrize("error", ("server", "connection")) +@mark_sync_test +def test_server_error_propagates(scripted_connection, error): + connection = scripted_connection + script = [ + # res 1 + ("run", {"on_success": ({"fields": ["n"]},), "on_summary": None}), + ("pull", {"on_records": ([[1], [2]],), + "on_success": ({"has_more": True},)}), + # res 2 + ("run", {"on_success": ({"fields": ["n"]},), "on_summary": None}), + ("pull", {"on_records": ([[1], [2]],), + "on_success": ({"has_more": True},)}), + ] + if error == "server": + script.append( + ("pull", {"on_failure": ({"code": "Neo.ClientError.Made.Up"},), + "on_summary": None}) + ) + expected_error = ClientError + elif error == "connection": + script.append(("pull", ServiceUnavailable())) + expected_error = ServiceUnavailable + else: + raise ValueError(f"Unknown error type {error}") + connection.set_script(script) + + tx = Transaction( + connection, 2, lambda *args, **kwargs: None, + lambda *args, **kwargs: None, lambda *args, **kwargs: None + ) + res1 = tx.run("UNWIND range(1, 1000) AS n RETURN n") + assert res1.__next__() == {"n": 1} + + res2 = tx.run("RETURN 'causes error later'") + assert res2.fetch(2) == [{"n": 1}, {"n": 2}] + with pytest.raises(expected_error) as exc1: + res2.__next__() + + # can finish the buffer + assert res1.fetch(1) == [{"n": 2}] + # then fails because the connection was broken by res2 + with pytest.raises(ResultFailedError) as exc2: + res1.__next__() + + assert exc1.value is exc2.value.__cause__