From a7261543a0e7ed2bbe3708686dfa998f3b3ab4d7 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Thu, 21 Dec 2023 14:57:03 +0100 Subject: [PATCH 1/2] Harden driver against unexpected RESET responses The server has been observed to reply with `FAILURE` and `IGNORED` to `RESET` requests. The former is according to spec and the driver should drop the connection (which it didn't), the latter isn't. The right combination of those two unexpected responses at the right time could get the driver stuck in an infinite loop. This change makes the driver drop the connection in either case to gracefully handle the situation. --- src/neo4j/_async/io/_bolt.py | 2 +- src/neo4j/_async/io/_bolt3.py | 16 +++++++--------- src/neo4j/_async/io/_bolt4.py | 16 +++++++--------- src/neo4j/_async/io/_bolt5.py | 11 +++-------- src/neo4j/_async/io/_common.py | 20 ++++++++++++++++++++ src/neo4j/_sync/io/_bolt.py | 2 +- src/neo4j/_sync/io/_bolt3.py | 16 +++++++--------- src/neo4j/_sync/io/_bolt4.py | 16 +++++++--------- src/neo4j/_sync/io/_bolt5.py | 11 +++-------- src/neo4j/_sync/io/_common.py | 20 ++++++++++++++++++++ 10 files changed, 76 insertions(+), 54 deletions(-) diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index 903c205c4..2efbc3a8c 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -857,7 +857,7 @@ async def fetch_all(self): messages fetched """ detail_count = summary_count = 0 - while self.responses: + while not self._closed and self.responses: response = self.responses[0] while not response.complete: detail_delta, summary_delta = await self.fetch_message() diff --git a/src/neo4j/_async/io/_bolt3.py b/src/neo4j/_async/io/_bolt3.py index d3c9837ad..9610ab72f 100644 --- a/src/neo4j/_async/io/_bolt3.py +++ b/src/neo4j/_async/io/_bolt3.py @@ -45,6 +45,7 @@ check_supported_server_product, CommitResponse, InitResponse, + ResetResponse, Response, ) @@ -391,17 +392,14 @@ def rollback(self, dehydration_hooks=None, hydration_hooks=None, dehydration_hooks=dehydration_hooks) async def reset(self, dehydration_hooks=None, hydration_hooks=None): - """ Add a RESET message to the outgoing queue, send - it and consume all remaining messages. - """ - - def fail(metadata): - raise BoltProtocolError("RESET failed %r" % metadata, address=self.unresolved_address) + """Reset the connection. + Add a RESET message to the outgoing queue, send it and consume all + remaining messages. + """ log.debug("[#%04X] C: RESET", self.local_port) - self._append(b"\x0F", - response=Response(self, "reset", hydration_hooks, - on_failure=fail), + response = ResetResponse(self, "reset", hydration_hooks) + self._append(b"\x0F", response=response, dehydration_hooks=dehydration_hooks) await self.send_all() await self.fetch_all() diff --git a/src/neo4j/_async/io/_bolt4.py b/src/neo4j/_async/io/_bolt4.py index bda0b9641..2d99a0f47 100644 --- a/src/neo4j/_async/io/_bolt4.py +++ b/src/neo4j/_async/io/_bolt4.py @@ -47,6 +47,7 @@ check_supported_server_product, CommitResponse, InitResponse, + ResetResponse, Response, ) @@ -311,17 +312,14 @@ def rollback(self, dehydration_hooks=None, hydration_hooks=None, dehydration_hooks=dehydration_hooks) async def reset(self, dehydration_hooks=None, hydration_hooks=None): - """ Add a RESET message to the outgoing queue, send - it and consume all remaining messages. - """ - - def fail(metadata): - raise BoltProtocolError("RESET failed %r" % metadata, self.unresolved_address) + """Reset the connection. + Add a RESET message to the outgoing queue, send it and consume all + remaining messages. + """ log.debug("[#%04X] C: RESET", self.local_port) - self._append(b"\x0F", - response=Response(self, "reset", hydration_hooks, - on_failure=fail), + response = ResetResponse(self, "reset", hydration_hooks) + self._append(b"\x0F", response=response, dehydration_hooks=dehydration_hooks) await self.send_all() await self.fetch_all() diff --git a/src/neo4j/_async/io/_bolt5.py b/src/neo4j/_async/io/_bolt5.py index 455eeb230..dda595d08 100644 --- a/src/neo4j/_async/io/_bolt5.py +++ b/src/neo4j/_async/io/_bolt5.py @@ -50,6 +50,7 @@ CommitResponse, InitResponse, LogonResponse, + ResetResponse, Response, ) @@ -314,15 +315,9 @@ async def reset(self, dehydration_hooks=None, hydration_hooks=None): Add a RESET message to the outgoing queue, send it and consume all remaining messages. """ - - def fail(metadata): - raise BoltProtocolError("RESET failed %r" % metadata, - self.unresolved_address) - log.debug("[#%04X] C: RESET", self.local_port) - self._append(b"\x0F", - response=Response(self, "reset", hydration_hooks, - on_failure=fail), + response = ResetResponse(self, "reset", hydration_hooks) + self._append(b"\x0F", response=response, dehydration_hooks=dehydration_hooks) await self.send_all() await self.fetch_all() diff --git a/src/neo4j/_async/io/_common.py b/src/neo4j/_async/io/_common.py index 3284abe2c..a5c5c2dfc 100644 --- a/src/neo4j/_async/io/_common.py +++ b/src/neo4j/_async/io/_common.py @@ -281,6 +281,26 @@ async def on_failure(self, metadata): raise Neo4jError.hydrate(**metadata) +class ResetResponse(Response): + async def _unexpected_message(self, response): + log.warning("[#%04X] _: RESET received %s " + "(unexpected response) => dropping connection", + self.connection.local_port, response) + await self.connection.close() + + async def on_records(self, records): + await self._unexpected_message("RECORD") + + async def on_success(self, metadata): + pass + + async def on_failure(self, metadata): + await self._unexpected_message("FAILURE") + + async def on_ignored(self, metadata=None): + await self._unexpected_message("IGNORED") + + class CommitResponse(Response): pass diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index 967d63036..69caaa6c2 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -857,7 +857,7 @@ def fetch_all(self): messages fetched """ detail_count = summary_count = 0 - while self.responses: + while not self._closed and self.responses: response = self.responses[0] while not response.complete: detail_delta, summary_delta = self.fetch_message() diff --git a/src/neo4j/_sync/io/_bolt3.py b/src/neo4j/_sync/io/_bolt3.py index 5d9741179..cf0fd4eba 100644 --- a/src/neo4j/_sync/io/_bolt3.py +++ b/src/neo4j/_sync/io/_bolt3.py @@ -45,6 +45,7 @@ check_supported_server_product, CommitResponse, InitResponse, + ResetResponse, Response, ) @@ -391,17 +392,14 @@ def rollback(self, dehydration_hooks=None, hydration_hooks=None, dehydration_hooks=dehydration_hooks) def reset(self, dehydration_hooks=None, hydration_hooks=None): - """ Add a RESET message to the outgoing queue, send - it and consume all remaining messages. - """ - - def fail(metadata): - raise BoltProtocolError("RESET failed %r" % metadata, address=self.unresolved_address) + """Reset the connection. + Add a RESET message to the outgoing queue, send it and consume all + remaining messages. + """ log.debug("[#%04X] C: RESET", self.local_port) - self._append(b"\x0F", - response=Response(self, "reset", hydration_hooks, - on_failure=fail), + response = ResetResponse(self, "reset", hydration_hooks) + self._append(b"\x0F", response=response, dehydration_hooks=dehydration_hooks) self.send_all() self.fetch_all() diff --git a/src/neo4j/_sync/io/_bolt4.py b/src/neo4j/_sync/io/_bolt4.py index 13969257b..04f092e63 100644 --- a/src/neo4j/_sync/io/_bolt4.py +++ b/src/neo4j/_sync/io/_bolt4.py @@ -47,6 +47,7 @@ check_supported_server_product, CommitResponse, InitResponse, + ResetResponse, Response, ) @@ -311,17 +312,14 @@ def rollback(self, dehydration_hooks=None, hydration_hooks=None, dehydration_hooks=dehydration_hooks) def reset(self, dehydration_hooks=None, hydration_hooks=None): - """ Add a RESET message to the outgoing queue, send - it and consume all remaining messages. - """ - - def fail(metadata): - raise BoltProtocolError("RESET failed %r" % metadata, self.unresolved_address) + """Reset the connection. + Add a RESET message to the outgoing queue, send it and consume all + remaining messages. + """ log.debug("[#%04X] C: RESET", self.local_port) - self._append(b"\x0F", - response=Response(self, "reset", hydration_hooks, - on_failure=fail), + response = ResetResponse(self, "reset", hydration_hooks) + self._append(b"\x0F", response=response, dehydration_hooks=dehydration_hooks) self.send_all() self.fetch_all() diff --git a/src/neo4j/_sync/io/_bolt5.py b/src/neo4j/_sync/io/_bolt5.py index 7f691f9b2..2c46daea0 100644 --- a/src/neo4j/_sync/io/_bolt5.py +++ b/src/neo4j/_sync/io/_bolt5.py @@ -50,6 +50,7 @@ CommitResponse, InitResponse, LogonResponse, + ResetResponse, Response, ) @@ -314,15 +315,9 @@ def reset(self, dehydration_hooks=None, hydration_hooks=None): Add a RESET message to the outgoing queue, send it and consume all remaining messages. """ - - def fail(metadata): - raise BoltProtocolError("RESET failed %r" % metadata, - self.unresolved_address) - log.debug("[#%04X] C: RESET", self.local_port) - self._append(b"\x0F", - response=Response(self, "reset", hydration_hooks, - on_failure=fail), + response = ResetResponse(self, "reset", hydration_hooks) + self._append(b"\x0F", response=response, dehydration_hooks=dehydration_hooks) self.send_all() self.fetch_all() diff --git a/src/neo4j/_sync/io/_common.py b/src/neo4j/_sync/io/_common.py index b609bd0ab..d1bd35172 100644 --- a/src/neo4j/_sync/io/_common.py +++ b/src/neo4j/_sync/io/_common.py @@ -281,6 +281,26 @@ def on_failure(self, metadata): raise Neo4jError.hydrate(**metadata) +class ResetResponse(Response): + def _unexpected_message(self, response): + log.warning("[#%04X] _: RESET received %s " + "(unexpected response) => dropping connection", + self.connection.local_port, response) + self.connection.close() + + def on_records(self, records): + self._unexpected_message("RECORD") + + def on_success(self, metadata): + pass + + def on_failure(self, metadata): + self._unexpected_message("FAILURE") + + def on_ignored(self, metadata=None): + self._unexpected_message("IGNORED") + + class CommitResponse(Response): pass From 145df765cb634282c02aa0846db19ce5ab5848a1 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Fri, 22 Dec 2023 13:05:57 +0100 Subject: [PATCH 2/2] Add unit tests for ResetResponse handler --- tests/unit/async_/io/test__common.py | 107 ++++++++++++++++++++++++++- tests/unit/sync/io/test__common.py | 107 ++++++++++++++++++++++++++- 2 files changed, 212 insertions(+), 2 deletions(-) diff --git a/tests/unit/async_/io/test__common.py b/tests/unit/async_/io/test__common.py index 10d8cffc8..972f9e252 100644 --- a/tests/unit/async_/io/test__common.py +++ b/tests/unit/async_/io/test__common.py @@ -14,9 +14,14 @@ # limitations under the License. +import logging + import pytest -from neo4j._async.io._common import AsyncOutbox +from neo4j._async.io._common import ( + AsyncOutbox, + ResetResponse, +) from neo4j._codec.packstream.v1 import PackableBuffer from ...._async_compat import mark_async_test @@ -56,3 +61,103 @@ async def test_async_outbox_chunking(chunk_size, data, result, mocker): assert not await outbox.flush() socket_mock.sendall.assert_awaited_once() + + +def get_handler_arg(response): + if response == "RECORD": + return [] + elif response == "IGNORED": + return {} + elif response == "FAILURE": + return {} + elif response == "SUCCESS": + return {} + else: + raise ValueError(f"Unexpected response: {response}") + + +def call_handler(handler, response, arg=None): + if arg is None: + arg = get_handler_arg(response) + + if response == "RECORD": + return handler.on_records(arg) + elif response == "IGNORED": + return handler.on_ignored(arg) + elif response == "FAILURE": + return handler.on_failure(arg) + elif response == "SUCCESS": + return handler.on_success(arg) + else: + raise ValueError(f"Unexpected response: {response}") + + +@pytest.mark.parametrize( + ("response", "unexpected"), + ( + ("RECORD", True), + ("IGNORED", True), + ("FAILURE", True), + ("SUCCESS", False), + ) +) +@mark_async_test +async def test_reset_response_closes_connection_on_unexpected_responses( + response, unexpected, async_fake_connection +): + handler = ResetResponse(async_fake_connection, "reset", {}) + async_fake_connection.close.assert_not_called() + + await call_handler(handler, response) + + if unexpected: + async_fake_connection.close.assert_awaited_once() + else: + async_fake_connection.close.assert_not_called() + + +@pytest.mark.parametrize( + ("response", "unexpected"), + ( + ("RECORD", True), + ("IGNORED", True), + ("FAILURE", True), + ("SUCCESS", False), + ) +) +@mark_async_test +async def test_reset_response_logs_warning_on_unexpected_responses( + response, unexpected, async_fake_connection, caplog +): + handler = ResetResponse(async_fake_connection, "reset", {}) + + with caplog.at_level(logging.WARNING): + await call_handler(handler, response) + + log_message_found = any("RESET" in msg and "unexpected response" in msg + for msg in caplog.messages) + if unexpected: + assert log_message_found + else: + assert not log_message_found + + +@pytest.mark.parametrize("response", + ("RECORD", "IGNORED", "FAILURE", "SUCCESS")) +@mark_async_test +async def test_reset_response_never_calls_handlers( + response, async_fake_connection, mocker +): + handlers = { + key: mocker.AsyncMock(name=key) + for key in + ("on_records", "on_ignored", "on_failure", "on_success", "on_summary") + } + + handler = ResetResponse(async_fake_connection, "reset", {}, **handlers) + + arg = get_handler_arg(response) + await call_handler(handler, response, arg) + + for handler in handlers.values(): + handler.assert_not_called() diff --git a/tests/unit/sync/io/test__common.py b/tests/unit/sync/io/test__common.py index d2c6df233..03e63996a 100644 --- a/tests/unit/sync/io/test__common.py +++ b/tests/unit/sync/io/test__common.py @@ -14,10 +14,15 @@ # limitations under the License. +import logging + import pytest from neo4j._codec.packstream.v1 import PackableBuffer -from neo4j._sync.io._common import Outbox +from neo4j._sync.io._common import ( + Outbox, + ResetResponse, +) from ...._async_compat import mark_sync_test @@ -56,3 +61,103 @@ def test_async_outbox_chunking(chunk_size, data, result, mocker): assert not outbox.flush() socket_mock.sendall.assert_called_once() + + +def get_handler_arg(response): + if response == "RECORD": + return [] + elif response == "IGNORED": + return {} + elif response == "FAILURE": + return {} + elif response == "SUCCESS": + return {} + else: + raise ValueError(f"Unexpected response: {response}") + + +def call_handler(handler, response, arg=None): + if arg is None: + arg = get_handler_arg(response) + + if response == "RECORD": + return handler.on_records(arg) + elif response == "IGNORED": + return handler.on_ignored(arg) + elif response == "FAILURE": + return handler.on_failure(arg) + elif response == "SUCCESS": + return handler.on_success(arg) + else: + raise ValueError(f"Unexpected response: {response}") + + +@pytest.mark.parametrize( + ("response", "unexpected"), + ( + ("RECORD", True), + ("IGNORED", True), + ("FAILURE", True), + ("SUCCESS", False), + ) +) +@mark_sync_test +def test_reset_response_closes_connection_on_unexpected_responses( + response, unexpected, fake_connection +): + handler = ResetResponse(fake_connection, "reset", {}) + fake_connection.close.assert_not_called() + + call_handler(handler, response) + + if unexpected: + fake_connection.close.assert_called_once() + else: + fake_connection.close.assert_not_called() + + +@pytest.mark.parametrize( + ("response", "unexpected"), + ( + ("RECORD", True), + ("IGNORED", True), + ("FAILURE", True), + ("SUCCESS", False), + ) +) +@mark_sync_test +def test_reset_response_logs_warning_on_unexpected_responses( + response, unexpected, fake_connection, caplog +): + handler = ResetResponse(fake_connection, "reset", {}) + + with caplog.at_level(logging.WARNING): + call_handler(handler, response) + + log_message_found = any("RESET" in msg and "unexpected response" in msg + for msg in caplog.messages) + if unexpected: + assert log_message_found + else: + assert not log_message_found + + +@pytest.mark.parametrize("response", + ("RECORD", "IGNORED", "FAILURE", "SUCCESS")) +@mark_sync_test +def test_reset_response_never_calls_handlers( + response, fake_connection, mocker +): + handlers = { + key: mocker.MagicMock(name=key) + for key in + ("on_records", "on_ignored", "on_failure", "on_success", "on_summary") + } + + handler = ResetResponse(fake_connection, "reset", {}, **handlers) + + arg = get_handler_arg(response) + call_handler(handler, response, arg) + + for handler in handlers.values(): + handler.assert_not_called()