Skip to content

Commit 249c10b

Browse files
authored
[4.4] Fix pool closing connections too aggressively (#1038)
Whenever a new routing table was fetched, the pool would close all connections to servers that were not part of the routing table. However, it might well be, that a missing server is present still in the routing table for another database. Hence, the pool now checks the routing tables for all databases before deciding which connections are no longer needed.g
1 parent 0823655 commit 249c10b

10 files changed

+372
-233
lines changed

neo4j/io/__init__.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -1235,7 +1235,13 @@ def update_routing_table(self, *, database, imp_user, bookmarks,
12351235
raise ServiceUnavailable("Unable to retrieve routing information")
12361236

12371237
def update_connection_pool(self, *, database):
1238-
servers = self.get_or_create_routing_table(database).servers()
1238+
with self.refresh_lock:
1239+
routing_tables = [self.get_or_create_routing_table(database)]
1240+
for db in self.routing_tables.keys():
1241+
if db == database:
1242+
continue
1243+
routing_tables.append(self.routing_tables[db])
1244+
servers = set.union(*(rt.servers() for rt in routing_tables))
12391245
for address in list(self.connections):
12401246
if address.unresolved not in servers:
12411247
super(Neo4jPool, self).deactivate(address)

tests/unit/conftest.py

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [http://neo4j.com]
3+
#
4+
# This file is part of Neo4j.
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
18+
19+
from .fixtures import * # necessary for pytest to discover the fixtures

tests/unit/fixtures/__init__.py

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [http://neo4j.com]
3+
#
4+
# This file is part of Neo4j.
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
18+
19+
from ._fake_connection import *
+122
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [http://neo4j.com]
3+
#
4+
# This file is part of Neo4j.
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
18+
19+
import inspect
20+
21+
import pytest
22+
23+
from neo4j import ServerInfo
24+
from neo4j._deadline import Deadline
25+
26+
27+
__all__ = [
28+
"fake_connection",
29+
"fake_connection_generator",
30+
]
31+
32+
33+
@pytest.fixture
34+
def fake_connection_generator(session_mocker):
35+
mock = session_mocker.mock_module
36+
37+
class FakeConnection(mock.NonCallableMagicMock):
38+
callbacks = []
39+
server_info = ServerInfo("127.0.0.1", (4, 3))
40+
local_port = 1234
41+
bolt_patches = set()
42+
43+
def __init__(self, *args, **kwargs):
44+
super().__init__(*args, **kwargs)
45+
self.attach_mock(mock.Mock(return_value=True), "is_reset_mock")
46+
self.attach_mock(mock.Mock(return_value=False), "defunct")
47+
self.attach_mock(mock.Mock(return_value=False), "stale")
48+
self.attach_mock(mock.Mock(return_value=False), "closed")
49+
self.attach_mock(mock.Mock(return_value=False), "socket")
50+
self.socket.attach_mock(
51+
mock.Mock(return_value=None), "get_deadline"
52+
)
53+
54+
def set_deadline_side_effect(deadline):
55+
deadline = Deadline.from_timeout_or_deadline(deadline)
56+
self.socket.get_deadline.return_value = deadline
57+
58+
self.socket.attach_mock(
59+
mock.Mock(side_effect=set_deadline_side_effect), "set_deadline"
60+
)
61+
62+
def close_side_effect():
63+
self.closed.return_value = True
64+
65+
self.attach_mock(mock.Mock(side_effect=close_side_effect), "close")
66+
67+
@property
68+
def is_reset(self):
69+
if self.closed.return_value or self.defunct.return_value:
70+
raise AssertionError("is_reset should not be called on a closed or "
71+
"defunct connection.")
72+
return self.is_reset_mock()
73+
74+
def fetch_message(self, *args, **kwargs):
75+
if self.callbacks:
76+
cb = self.callbacks.pop(0)
77+
cb()
78+
return super().__getattr__("fetch_message")(*args, **kwargs)
79+
80+
def fetch_all(self, *args, **kwargs):
81+
while self.callbacks:
82+
cb = self.callbacks.pop(0)
83+
cb()
84+
return super().__getattr__("fetch_all")(*args, **kwargs)
85+
86+
def __getattr__(self, name):
87+
parent = super()
88+
89+
def build_message_handler(name):
90+
def func(*args, **kwargs):
91+
def callback():
92+
for cb_name, param_count in (
93+
("on_success", 1),
94+
("on_summary", 0)
95+
):
96+
cb = kwargs.get(cb_name, None)
97+
if callable(cb):
98+
try:
99+
param_count = \
100+
len(inspect.signature(cb).parameters)
101+
except ValueError:
102+
# e.g. built-in method as cb
103+
pass
104+
if param_count == 1:
105+
cb({})
106+
else:
107+
cb()
108+
self.callbacks.append(callback)
109+
110+
return func
111+
112+
method_mock = parent.__getattr__(name)
113+
if name in ("run", "commit", "pull", "rollback", "discard"):
114+
method_mock.side_effect = build_message_handler(name)
115+
return method_mock
116+
117+
return FakeConnection
118+
119+
120+
@pytest.fixture
121+
def fake_connection(fake_connection_generator):
122+
return fake_connection_generator()

tests/unit/io/test__common.py

-2
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@
2525
ResetResponse,
2626
)
2727

28-
from ..work import fake_connection
29-
3028

3129
@pytest.mark.parametrize(("chunk_size", "data", "result"), (
3230
(

0 commit comments

Comments
 (0)