diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py index 0afa6b6c3..cdda93c3c 100644 --- a/neo4j/io/__init__.py +++ b/neo4j/io/__init__.py @@ -1235,7 +1235,13 @@ def update_routing_table(self, *, database, imp_user, bookmarks, raise ServiceUnavailable("Unable to retrieve routing information") def update_connection_pool(self, *, database): - servers = self.get_or_create_routing_table(database).servers() + with self.refresh_lock: + routing_tables = [self.get_or_create_routing_table(database)] + for db in self.routing_tables.keys(): + if db == database: + continue + routing_tables.append(self.routing_tables[db]) + servers = set.union(*(rt.servers() for rt in routing_tables)) for address in list(self.connections): if address.unresolved not in servers: super(Neo4jPool, self).deactivate(address) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py new file mode 100644 index 000000000..c03179758 --- /dev/null +++ b/tests/unit/conftest.py @@ -0,0 +1,19 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .fixtures import * # necessary for pytest to discover the fixtures diff --git a/tests/unit/fixtures/__init__.py b/tests/unit/fixtures/__init__.py new file mode 100644 index 000000000..f17c825e2 --- /dev/null +++ b/tests/unit/fixtures/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ._fake_connection import * diff --git a/tests/unit/fixtures/_fake_connection.py b/tests/unit/fixtures/_fake_connection.py new file mode 100644 index 000000000..36d610c0c --- /dev/null +++ b/tests/unit/fixtures/_fake_connection.py @@ -0,0 +1,122 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect + +import pytest + +from neo4j import ServerInfo +from neo4j._deadline import Deadline + + +__all__ = [ + "fake_connection", + "fake_connection_generator", +] + + +@pytest.fixture +def fake_connection_generator(session_mocker): + mock = session_mocker.mock_module + + class FakeConnection(mock.NonCallableMagicMock): + callbacks = [] + server_info = ServerInfo("127.0.0.1", (4, 3)) + local_port = 1234 + bolt_patches = set() + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.attach_mock(mock.Mock(return_value=True), "is_reset_mock") + self.attach_mock(mock.Mock(return_value=False), "defunct") + self.attach_mock(mock.Mock(return_value=False), "stale") + self.attach_mock(mock.Mock(return_value=False), "closed") + self.attach_mock(mock.Mock(return_value=False), "socket") + self.socket.attach_mock( + mock.Mock(return_value=None), "get_deadline" + ) + + def set_deadline_side_effect(deadline): + deadline = Deadline.from_timeout_or_deadline(deadline) + self.socket.get_deadline.return_value = deadline + + self.socket.attach_mock( + mock.Mock(side_effect=set_deadline_side_effect), "set_deadline" + ) + + def close_side_effect(): + self.closed.return_value = True + + self.attach_mock(mock.Mock(side_effect=close_side_effect), "close") + + @property + def is_reset(self): + if self.closed.return_value or self.defunct.return_value: + raise AssertionError("is_reset should not be called on a closed or " + "defunct connection.") + return self.is_reset_mock() + + def fetch_message(self, *args, **kwargs): + if self.callbacks: + cb = self.callbacks.pop(0) + cb() + return super().__getattr__("fetch_message")(*args, **kwargs) + + def fetch_all(self, *args, **kwargs): + while self.callbacks: + cb = self.callbacks.pop(0) + cb() + return super().__getattr__("fetch_all")(*args, **kwargs) + + def __getattr__(self, name): + parent = super() + + def build_message_handler(name): + def func(*args, **kwargs): + def callback(): + for cb_name, param_count in ( + ("on_success", 1), + ("on_summary", 0) + ): + cb = kwargs.get(cb_name, None) + if callable(cb): + try: + param_count = \ + len(inspect.signature(cb).parameters) + except ValueError: + # e.g. built-in method as cb + pass + if param_count == 1: + cb({}) + else: + cb() + self.callbacks.append(callback) + + return func + + method_mock = parent.__getattr__(name) + if name in ("run", "commit", "pull", "rollback", "discard"): + method_mock.side_effect = build_message_handler(name) + return method_mock + + return FakeConnection + + +@pytest.fixture +def fake_connection(fake_connection_generator): + return fake_connection_generator() diff --git a/tests/unit/io/test__common.py b/tests/unit/io/test__common.py index 72f9ee921..b70357c17 100644 --- a/tests/unit/io/test__common.py +++ b/tests/unit/io/test__common.py @@ -25,8 +25,6 @@ ResetResponse, ) -from ..work import fake_connection - @pytest.mark.parametrize(("chunk_size", "data", "result"), ( ( diff --git a/tests/unit/io/test_neo4j_pool.py b/tests/unit/io/test_neo4j_pool.py index 730875349..4d2eff54b 100644 --- a/tests/unit/io/test_neo4j_pool.py +++ b/tests/unit/io/test_neo4j_pool.py @@ -2,15 +2,13 @@ # -*- encoding: utf-8 -*- # Copyright (c) "Neo4j" -# Neo4j Sweden AB [http://neo4j.com] -# -# This file is part of Neo4j. +# Neo4j Sweden AB [https://neo4j.com] # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -19,62 +17,99 @@ # limitations under the License. -from unittest.mock import Mock +import inspect import pytest -from ..work import FakeConnection - from neo4j import ( READ_ACCESS, WRITE_ACCESS, ) +from neo4j._deadline import Deadline from neo4j.addressing import ResolvedAddress from neo4j.conf import ( PoolConfig, RoutingConfig, - WorkspaceConfig + WorkspaceConfig, ) -from neo4j._deadline import Deadline from neo4j.exceptions import ( ServiceUnavailable, - SessionExpired + SessionExpired, ) from neo4j.io import Neo4jPool -ROUTER_ADDRESS = ResolvedAddress(("1.2.3.1", 9001), host_name="host") -READER_ADDRESS = ResolvedAddress(("1.2.3.1", 9002), host_name="host") -WRITER_ADDRESS = ResolvedAddress(("1.2.3.1", 9003), host_name="host") - - -@pytest.fixture() -def opener(): - def open_(addr, timeout): - connection = FakeConnection() - connection.addr = addr - connection.timeout = timeout - route_mock = Mock() - route_mock.return_value = [{ - "ttl": 1000, - "servers": [ - {"addresses": [str(ROUTER_ADDRESS)], "role": "ROUTE"}, - {"addresses": [str(READER_ADDRESS)], "role": "READ"}, - {"addresses": [str(WRITER_ADDRESS)], "role": "WRITE"}, - ], - }] - connection.attach_mock(route_mock, "route") - opener_.connections.append(connection) - return connection - - opener_ = Mock() - opener_.connections = [] - opener_.side_effect = open_ - return opener_ +ROUTER1_ADDRESS = ResolvedAddress(("1.2.3.1", 9000), host_name="host") +ROUTER2_ADDRESS = ResolvedAddress(("1.2.3.1", 9001), host_name="host") +ROUTER3_ADDRESS = ResolvedAddress(("1.2.3.1", 9002), host_name="host") +READER1_ADDRESS = ResolvedAddress(("1.2.3.1", 9010), host_name="host") +READER2_ADDRESS = ResolvedAddress(("1.2.3.1", 9011), host_name="host") +READER3_ADDRESS = ResolvedAddress(("1.2.3.1", 9012), host_name="host") +WRITER1_ADDRESS = ResolvedAddress(("1.2.3.1", 9020), host_name="host") + + +@pytest.fixture +def custom_routing_opener(fake_connection_generator, mocker): + def make_opener(failures=None, get_readers=None): + def routing_side_effect(*args, **kwargs): + nonlocal failures + res = next(failures, None) + if res is None: + if get_readers is not None: + readers = get_readers(kwargs.get("database", args[0])) + else: + readers = [str(READER1_ADDRESS)] + return [{ + "ttl": 1000, + "servers": [ + {"addresses": [str(ROUTER1_ADDRESS), + str(ROUTER2_ADDRESS), + str(ROUTER3_ADDRESS)], + "role": "ROUTE"}, + {"addresses": readers, "role": "READ"}, + {"addresses": [str(WRITER1_ADDRESS)], "role": "WRITE"}, + ], + }] + raise res + + def open_(addr, deadline): + connection = fake_connection_generator() + connection.unresolved_address = addr + connection.deadline = deadline + route_mock = mocker.MagicMock() + + route_mock.side_effect = routing_side_effect + connection.attach_mock(route_mock, "route") + opener_.connections.append(connection) + return connection + + failures = iter(failures or []) + opener_ = mocker.MagicMock() + opener_.connections = [] + opener_.side_effect = open_ + return opener_ + + return make_opener + + +@pytest.fixture +def opener(custom_routing_opener): + return custom_routing_opener() + + +def _pool_config(): + pool_config = PoolConfig() + return pool_config + + +def _simple_pool(opener) -> Neo4jPool: + return Neo4jPool( + opener, _pool_config(), WorkspaceConfig(), ROUTER1_ADDRESS + ) def test_acquires_new_routing_table_if_deleted(opener): - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + pool = _simple_pool(opener) cx = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) pool.release(cx) assert pool.routing_tables.get("test_db") @@ -87,7 +122,7 @@ def test_acquires_new_routing_table_if_deleted(opener): def test_acquires_new_routing_table_if_stale(opener): - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + pool = _simple_pool(opener) cx = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) pool.release(cx) assert pool.routing_tables.get("test_db") @@ -101,7 +136,7 @@ def test_acquires_new_routing_table_if_stale(opener): def test_removes_old_routing_table(opener): - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + pool = _simple_pool(opener) cx = pool.acquire(READ_ACCESS, 30, 60, "test_db1", None) pool.release(cx) assert pool.routing_tables.get("test_db1") @@ -122,18 +157,18 @@ def test_removes_old_routing_table(opener): @pytest.mark.parametrize("type_", ("r", "w")) def test_chooses_right_connection_type(opener, type_): - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + pool = _simple_pool(opener) cx1 = pool.acquire(READ_ACCESS if type_ == "r" else WRITE_ACCESS, 30, 60, "test_db", None) pool.release(cx1) if type_ == "r": - assert cx1.addr == READER_ADDRESS + assert cx1.unresolved_address == READER1_ADDRESS else: - assert cx1.addr == WRITER_ADDRESS + assert cx1.unresolved_address == WRITER1_ADDRESS def test_reuses_connection(opener): - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + pool = _simple_pool(opener) cx1 = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) pool.release(cx1) cx2 = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) @@ -143,17 +178,19 @@ def test_reuses_connection(opener): @pytest.mark.parametrize("break_on_close", (True, False)) def test_closes_stale_connections(opener, break_on_close): def break_connection(): - pool.deactivate(cx1.addr) + pool.deactivate(cx1.unresolved_address) if cx_close_mock_side_effect: - cx_close_mock_side_effect() + res = cx_close_mock_side_effect() + if inspect.isawaitable(res): + return res - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + pool = _simple_pool(opener) cx1 = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) pool.release(cx1) - assert cx1 in pool.connections[cx1.addr] - # simulate connection going stale (e.g. exceeding) and then breaking when - # the pool tries to close the connection + assert cx1 in pool.connections[cx1.unresolved_address] + # simulate connection going stale (e.g. exceeding idle timeout) and then + # breaking when the pool tries to close the connection cx1.stale.return_value = True cx_close_mock = cx1.close if break_on_close: @@ -166,24 +203,25 @@ def break_connection(): else: cx1.close.assert_called_once() assert cx2 is not cx1 - assert cx2.addr == cx1.addr - assert cx1 not in pool.connections[cx1.addr] - assert cx2 in pool.connections[cx2.addr] + assert cx2.unresolved_address == cx1.unresolved_address + assert cx1 not in pool.connections[cx1.unresolved_address] + assert cx2 in pool.connections[cx2.unresolved_address] def test_does_not_close_stale_connections_in_use(opener): - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + pool = _simple_pool(opener) cx1 = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) - assert cx1 in pool.connections[cx1.addr] - # simulate connection going stale (e.g. exceeding) while being in use + assert cx1 in pool.connections[cx1.unresolved_address] + # simulate connection going stale (e.g. exceeding idle timeout) while being + # in use cx1.stale.return_value = True cx2 = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) pool.release(cx2) cx1.close.assert_not_called() assert cx2 is not cx1 - assert cx2.addr == cx1.addr - assert cx1 in pool.connections[cx1.addr] - assert cx2 in pool.connections[cx2.addr] + assert cx2.unresolved_address == cx1.unresolved_address + assert cx1 in pool.connections[cx1.unresolved_address] + assert cx2 in pool.connections[cx2.unresolved_address] pool.release(cx1) # now that cx1 is back in the pool and still stale, @@ -194,13 +232,13 @@ def test_does_not_close_stale_connections_in_use(opener): pool.release(cx3) cx1.close.assert_called_once() assert cx2 is cx3 - assert cx3.addr == cx1.addr - assert cx1 not in pool.connections[cx1.addr] - assert cx3 in pool.connections[cx2.addr] + assert cx3.unresolved_address == cx1.unresolved_address + assert cx1 not in pool.connections[cx1.unresolved_address] + assert cx3 in pool.connections[cx2.unresolved_address] def test_release_resets_connections(opener): - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + pool = _simple_pool(opener) cx1 = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) cx1.is_reset_mock.return_value = False cx1.is_reset_mock.reset_mock() @@ -210,40 +248,41 @@ def test_release_resets_connections(opener): def test_release_does_not_resets_closed_connections(opener): - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + pool = _simple_pool(opener) cx1 = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) cx1.closed.return_value = True cx1.closed.reset_mock() cx1.is_reset_mock.reset_mock() pool.release(cx1) cx1.closed.assert_called_once() - cx1.is_reset_mock.asset_not_called() - cx1.reset.asset_not_called() + cx1.is_reset_mock.assert_not_called() + cx1.reset.assert_not_called() def test_release_does_not_resets_defunct_connections(opener): - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + pool = _simple_pool(opener) cx1 = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) cx1.defunct.return_value = True cx1.defunct.reset_mock() cx1.is_reset_mock.reset_mock() pool.release(cx1) cx1.defunct.assert_called_once() - cx1.is_reset_mock.asset_not_called() - cx1.reset.asset_not_called() + cx1.is_reset_mock.assert_not_called() + cx1.reset.assert_not_called() -def test_multiple_broken_connections_on_close(opener): +def test_multiple_broken_connections_on_close(opener, mocker): def mock_connection_breaks_on_close(cx): def close_side_effect(): cx.closed.return_value = True cx.defunct.return_value = True - pool.deactivate(READER_ADDRESS) + pool.deactivate(READER1_ADDRESS) - cx.attach_mock(Mock(side_effect=close_side_effect), "close") + cx.attach_mock(mocker.MagicMock(side_effect=close_side_effect), + "close") # create pool with 2 idle connections - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + pool = _simple_pool(opener) cx1 = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) cx2 = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) pool.release(cx1) @@ -264,37 +303,115 @@ def close_side_effect(): def test_failing_opener_leaves_connections_in_use_alone(opener): - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + pool = _simple_pool(opener) cx1 = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) opener.side_effect = ServiceUnavailable("Server overloaded") with pytest.raises((ServiceUnavailable, SessionExpired)): pool.acquire(READ_ACCESS, 30, 60, "test_db", None) - assert not cx1.closed() def test__acquire_new_later_with_room(opener): - config = PoolConfig() + config = _pool_config() config.max_connection_pool_size = 1 pool = Neo4jPool( - opener, config, WorkspaceConfig(), ROUTER_ADDRESS + opener, config, WorkspaceConfig(), ROUTER1_ADDRESS ) - assert pool.connections_reservations[READER_ADDRESS] == 0 - creator = pool._acquire_new_later(READER_ADDRESS, Deadline(1)) - assert pool.connections_reservations[READER_ADDRESS] == 1 + assert pool.connections_reservations[READER1_ADDRESS] == 0 + creator = pool._acquire_new_later(READER1_ADDRESS, Deadline(1)) + assert pool.connections_reservations[READER1_ADDRESS] == 1 assert callable(creator) def test__acquire_new_later_without_room(opener): - config = PoolConfig() + config = _pool_config() config.max_connection_pool_size = 1 pool = Neo4jPool( - opener, config, WorkspaceConfig(), ROUTER_ADDRESS + opener, config, WorkspaceConfig(), ROUTER1_ADDRESS ) _ = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) # pool is full now - assert pool.connections_reservations[READER_ADDRESS] == 0 - creator = pool._acquire_new_later(READER_ADDRESS, Deadline(1)) - assert pool.connections_reservations[READER_ADDRESS] == 0 + assert pool.connections_reservations[READER1_ADDRESS] == 0 + creator = pool._acquire_new_later(READER1_ADDRESS, Deadline(1)) + assert pool.connections_reservations[READER1_ADDRESS] == 0 assert creator is None + + +def test_pool_closes_connections_dropped_from_rt(custom_routing_opener): + readers = {"db1": [str(READER1_ADDRESS)]} + + def get_readers(database): + return readers[database] + + opener = custom_routing_opener(get_readers=get_readers) + + pool = Neo4jPool( + opener, _pool_config(), WorkspaceConfig(), ROUTER1_ADDRESS + ) + cx1 = pool.acquire(READ_ACCESS, 30, 60, "db1", None) + assert cx1.unresolved_address == READER1_ADDRESS + pool.release(cx1) + + cx1.close.assert_not_called() + assert len(pool.connections[READER1_ADDRESS]) == 1 + + # force RT refresh, returning a different reader + del pool.routing_tables["db1"] + readers["db1"] = [str(READER2_ADDRESS)] + + cx2 = pool.acquire(READ_ACCESS, 30, 60, "db1", None) + assert cx2.unresolved_address == READER2_ADDRESS + + cx1.close.assert_called_once() + assert len(pool.connections[READER1_ADDRESS]) == 0 + + pool.release(cx2) + assert len(pool.connections[READER2_ADDRESS]) == 1 + + +def test_pool_does_not_close_connections_dropped_from_rt_for_other_server( + custom_routing_opener +): + readers = { + "db1": [str(READER1_ADDRESS), str(READER2_ADDRESS)], + "db2": [str(READER1_ADDRESS)] + } + + def get_readers(database): + return readers[database] + + opener = custom_routing_opener(get_readers=get_readers) + + pool = Neo4jPool( + opener, _pool_config(), WorkspaceConfig(), ROUTER1_ADDRESS + ) + cx1 = pool.acquire(READ_ACCESS, 30, 60, "db1", None) + pool.release(cx1) + assert cx1.unresolved_address in (READER1_ADDRESS, READER2_ADDRESS) + reader1_connection_count = len(pool.connections[READER1_ADDRESS]) + reader2_connection_count = len(pool.connections[READER2_ADDRESS]) + assert reader1_connection_count + reader2_connection_count == 1 + + cx2 = pool.acquire(READ_ACCESS, 30, 60, "db2", None) + pool.release(cx2) + assert cx2.unresolved_address == READER1_ADDRESS + cx1.close.assert_not_called() + cx2.close.assert_not_called() + assert len(pool.connections[READER1_ADDRESS]) == 1 + assert len(pool.connections[READER2_ADDRESS]) == reader2_connection_count + + # force RT refresh, returning a different reader + del pool.routing_tables["db2"] + readers["db2"] = [str(READER3_ADDRESS)] + + cx3 = pool.acquire(READ_ACCESS, 30, 60, "db2", None) + pool.release(cx3) + assert cx3.unresolved_address == READER3_ADDRESS + + cx1.close.assert_not_called() + cx2.close.assert_not_called() + cx3.close.assert_not_called() + assert len(pool.connections[READER1_ADDRESS]) == 1 + assert len(pool.connections[READER2_ADDRESS]) == reader2_connection_count + assert len(pool.connections[READER3_ADDRESS]) == 1 diff --git a/tests/unit/work/__init__.py b/tests/unit/work/__init__.py index 238e61d3f..e69de29bb 100644 --- a/tests/unit/work/__init__.py +++ b/tests/unit/work/__init__.py @@ -1,24 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -# Copyright (c) "Neo4j" -# Neo4j Sweden AB [http://neo4j.com] -# -# This file is part of Neo4j. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from ._fake_connection import ( - FakeConnection, - fake_connection, -) diff --git a/tests/unit/work/_fake_connection.py b/tests/unit/work/_fake_connection.py deleted file mode 100644 index 9bc6815ad..000000000 --- a/tests/unit/work/_fake_connection.py +++ /dev/null @@ -1,114 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -# Copyright (c) "Neo4j" -# Neo4j Sweden AB [http://neo4j.com] -# -# This file is part of Neo4j. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import inspect -from unittest import mock - -import pytest - -from neo4j import ServerInfo -from neo4j._deadline import Deadline - - -class FakeConnection(mock.NonCallableMagicMock): - callbacks = [] - server_info = ServerInfo("127.0.0.1", (4, 3)) - local_port = 1234 - bolt_patches = set() - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.attach_mock(mock.Mock(return_value=True), "is_reset_mock") - self.attach_mock(mock.Mock(return_value=False), "defunct") - self.attach_mock(mock.Mock(return_value=False), "stale") - self.attach_mock(mock.Mock(return_value=False), "closed") - self.attach_mock(mock.Mock(return_value=False), "socket") - self.socket.attach_mock( - mock.Mock(return_value=None), "get_deadline" - ) - - def set_deadline_side_effect(deadline): - deadline = Deadline.from_timeout_or_deadline(deadline) - self.socket.get_deadline.return_value = deadline - - self.socket.attach_mock( - mock.Mock(side_effect=set_deadline_side_effect), "set_deadline" - ) - - def close_side_effect(): - self.closed.return_value = True - - self.attach_mock(mock.Mock(side_effect=close_side_effect), "close") - - @property - def is_reset(self): - if self.closed.return_value or self.defunct.return_value: - raise AssertionError("is_reset should not be called on a closed or " - "defunct connection.") - return self.is_reset_mock() - - def fetch_message(self, *args, **kwargs): - if self.callbacks: - cb = self.callbacks.pop(0) - cb() - return super().__getattr__("fetch_message")(*args, **kwargs) - - def fetch_all(self, *args, **kwargs): - while self.callbacks: - cb = self.callbacks.pop(0) - cb() - return super().__getattr__("fetch_all")(*args, **kwargs) - - def __getattr__(self, name): - parent = super() - - def build_message_handler(name): - def func(*args, **kwargs): - def callback(): - for cb_name, param_count in ( - ("on_success", 1), - ("on_summary", 0) - ): - cb = kwargs.get(cb_name, None) - if callable(cb): - try: - param_count = \ - len(inspect.signature(cb).parameters) - except ValueError: - # e.g. built-in method as cb - pass - if param_count == 1: - cb({}) - else: - cb() - self.callbacks.append(callback) - - return func - - method_mock = parent.__getattr__(name) - if name in ("run", "commit", "pull", "rollback", "discard"): - method_mock.side_effect = build_message_handler(name) - return method_mock - - -@pytest.fixture -def fake_connection(): - return FakeConnection() diff --git a/tests/unit/work/test_session.py b/tests/unit/work/test_session.py index d43048699..64a9f494b 100644 --- a/tests/unit/work/test_session.py +++ b/tests/unit/work/test_session.py @@ -30,17 +30,15 @@ ) from neo4j.io import IOPool -from ._fake_connection import FakeConnection - @pytest.fixture -def pool(mocker): +def pool(mocker, fake_connection_generator): pool = mocker.Mock(spec=IOPool) assert not hasattr(pool, "acquired_connection_mocks") pool.acquired_connection_mocks = [] def acquire_side_effect(*_, **__): - connection = FakeConnection() + connection = fake_connection_generator() pool.acquired_connection_mocks.append(connection) return connection diff --git a/tests/unit/work/test_transaction.py b/tests/unit/work/test_transaction.py index 06e755662..3ac206ce2 100644 --- a/tests/unit/work/test_transaction.py +++ b/tests/unit/work/test_transaction.py @@ -31,8 +31,6 @@ Transaction, ) -from ._fake_connection import fake_connection - @pytest.mark.parametrize(("explicit_commit", "close"), ( (False, False),