Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 5 additions & 11 deletions neo4j/io/_bolt3.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
CommitResponse,
InitResponse,
Response,
tx_timeout_as_ms,
)


Expand Down Expand Up @@ -225,11 +226,8 @@ def run(self, query, parameters=None, mode=None, bookmarks=None,
extra["tx_metadata"] = dict(metadata)
except TypeError:
raise TypeError("Metadata must be coercible to a dict")
if timeout:
try:
extra["tx_timeout"] = int(1000 * timeout)
except TypeError:
raise TypeError("Timeout must be specified as a number of seconds")
if timeout or timeout == 0:
extra["tx_timeout"] = tx_timeout_as_ms(timeout)
fields = (query, parameters, extra)
log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields)))
if query.upper() == u"COMMIT":
Expand Down Expand Up @@ -277,12 +275,8 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None,
extra["tx_metadata"] = dict(metadata)
except TypeError:
raise TypeError("Metadata must be coercible to a dict")
if timeout:
try:
extra["tx_timeout"] = int(1000 * timeout)
except TypeError:
raise TypeError("Timeout must be specified as a number of seconds")
log.debug("[#%04X] C: BEGIN %r", self.local_port, extra)
if timeout or timeout == 0:
extra["tx_timeout"] = tx_timeout_as_ms(timeout)
self._append(b"\x11", (extra,), Response(self, "begin", **handlers))

def commit(self, **handlers):
Expand Down
32 changes: 9 additions & 23 deletions neo4j/io/_bolt4.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from neo4j.exceptions import (
ConfigurationError,
DatabaseUnavailable,
DriverError,
ForbiddenOnReadOnlyDatabase,
Neo4jError,
NotALeader,
Expand All @@ -48,6 +47,7 @@
CommitResponse,
InitResponse,
Response,
tx_timeout_as_ms,
)
from neo4j.io._bolt3 import (
ServerStateManager,
Expand Down Expand Up @@ -178,11 +178,8 @@ def run(self, query, parameters=None, mode=None, bookmarks=None,
extra["tx_metadata"] = dict(metadata)
except TypeError:
raise TypeError("Metadata must be coercible to a dict")
if timeout:
try:
extra["tx_timeout"] = int(1000 * timeout)
except TypeError:
raise TypeError("Timeout must be specified as a number of seconds")
if timeout or timeout == 0:
extra["tx_timeout"] = tx_timeout_as_ms(timeout)
fields = (query, parameters, extra)
log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields)))
if query.upper() == u"COMMIT":
Expand Down Expand Up @@ -229,11 +226,8 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None,
extra["tx_metadata"] = dict(metadata)
except TypeError:
raise TypeError("Metadata must be coercible to a dict")
if timeout:
try:
extra["tx_timeout"] = int(1000 * timeout)
except TypeError:
raise TypeError("Timeout must be specified as a number of seconds")
if timeout or timeout == 0:
extra["tx_timeout"] = tx_timeout_as_ms(timeout)
log.debug("[#%04X] C: BEGIN %r", self.local_port, extra)
self._append(b"\x11", (extra,), Response(self, "begin", **handlers))

Expand Down Expand Up @@ -490,12 +484,8 @@ def run(self, query, parameters=None, mode=None, bookmarks=None,
extra["tx_metadata"] = dict(metadata)
except TypeError:
raise TypeError("Metadata must be coercible to a dict")
if timeout:
try:
extra["tx_timeout"] = int(1000 * timeout)
except TypeError:
raise TypeError("Timeout must be specified as a number of "
"seconds")
if timeout or timeout == 0:
extra["tx_timeout"] = tx_timeout_as_ms(timeout)
fields = (query, parameters, extra)
log.debug("[#%04X] C: RUN %s", self.local_port,
" ".join(map(repr, fields)))
Expand Down Expand Up @@ -525,11 +515,7 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None,
extra["tx_metadata"] = dict(metadata)
except TypeError:
raise TypeError("Metadata must be coercible to a dict")
if timeout:
try:
extra["tx_timeout"] = int(1000 * timeout)
except TypeError:
raise TypeError("Timeout must be specified as a number of "
"seconds")
if timeout or timeout == 0:
extra["tx_timeout"] = tx_timeout_as_ms(timeout)
log.debug("[#%04X] C: BEGIN %r", self.local_port, extra)
self._append(b"\x11", (extra,), Response(self, "begin", **handlers))
30 changes: 30 additions & 0 deletions neo4j/io/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,3 +270,33 @@ def on_failure(self, metadata):
class CommitResponse(Response):

pass


def tx_timeout_as_ms(timeout: float) -> int:
"""
Round transaction timeout to milliseconds.

Values in (0, 1], else values are rounded using the built-in round()
function (round n.5 values to nearest even).

:param timeout: timeout in seconds (must be >= 0)

:returns: timeout in milliseconds (rounded)

:raise ValueError: if timeout is negative
"""
try:
timeout = float(timeout)
except (TypeError, ValueError) as e:
err_type = type(e)
msg = "Timeout must be specified as a number of seconds"
raise err_type(msg) from None
if timeout < 0:
raise ValueError("Timeout must be a positive number or 0.")
ms = int(round(1000 * timeout))
if ms == 0 and timeout > 0:
# Special case for 0 < timeout < 0.5 ms.
# This would be rounded to 0 ms, but the server interprets this as
# infinite timeout. So we round to the smallest possible timeout: 1 ms.
ms = 1
return ms
57 changes: 57 additions & 0 deletions tests/unit/io/test_class_bolt3.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,60 @@ def test_hint_recv_timeout_seconds_gets_ignored(fake_socket_pair, recv_timeout):
PoolConfig.max_connection_lifetime)
connection.hello()
sockets.client.settimeout.assert_not_called()


@pytest.mark.parametrize(
("func", "args", "extra_idx"),
(
("run", ("RETURN 1",), 2),
("begin", (), 0),
)
)
@pytest.mark.parametrize(
("timeout", "res"),
(
(None, None),
(0, 0),
(0.1, 100),
(0.001, 1),
(1e-15, 1),
(0.0005, 1),
(0.0001, 1),
(1.0015, 1002),
(1.000499, 1000),
(1.0025, 1002),
(3.0005, 3000),
(3.456, 3456),
(1, 1000),
(
-1e-15,
ValueError("Timeout must be a positive number or 0")
),
(
"foo",
ValueError("Timeout must be specified as a number of seconds")
),
(
[1, 2],
TypeError("Timeout must be specified as a number of seconds")
)
)
)
def test_tx_timeout(fake_socket_pair, func, args, extra_idx, timeout, res):
address = ("127.0.0.1", 7687)
sockets = fake_socket_pair(address)
sockets.server.send_message(0x70, {})
connection = Bolt3(address, sockets.client, 0)
func = getattr(connection, func)
if isinstance(res, Exception):
with pytest.raises(type(res), match=str(res)):
func(*args, timeout=timeout)
else:
func(*args, timeout=timeout)
connection.send_all()
tag, fields = sockets.server.pop_message()
extra = fields[extra_idx]
if timeout is None:
assert "tx_timeout" not in extra
else:
assert extra["tx_timeout"] == res
57 changes: 57 additions & 0 deletions tests/unit/io/test_class_bolt4x0.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,60 @@ def test_hint_recv_timeout_seconds_gets_ignored(fake_socket_pair, recv_timeout):
PoolConfig.max_connection_lifetime)
connection.hello()
sockets.client.settimeout.assert_not_called()


@pytest.mark.parametrize(
("func", "args", "extra_idx"),
(
("run", ("RETURN 1",), 2),
("begin", (), 0),
)
)
@pytest.mark.parametrize(
("timeout", "res"),
(
(None, None),
(0, 0),
(0.1, 100),
(0.001, 1),
(1e-15, 1),
(0.0005, 1),
(0.0001, 1),
(1.0015, 1002),
(1.000499, 1000),
(1.0025, 1002),
(3.0005, 3000),
(3.456, 3456),
(1, 1000),
(
-1e-15,
ValueError("Timeout must be a positive number or 0")
),
(
"foo",
ValueError("Timeout must be specified as a number of seconds")
),
(
[1, 2],
TypeError("Timeout must be specified as a number of seconds")
)
)
)
def test_tx_timeout(fake_socket_pair, func, args, extra_idx, timeout, res):
address = ("127.0.0.1", 7687)
sockets = fake_socket_pair(address)
sockets.server.send_message(0x70, {})
connection = Bolt4x0(address, sockets.client, 0)
func = getattr(connection, func)
if isinstance(res, Exception):
with pytest.raises(type(res), match=str(res)):
func(*args, timeout=timeout)
else:
func(*args, timeout=timeout)
connection.send_all()
tag, fields = sockets.server.pop_message()
extra = fields[extra_idx]
if timeout is None:
assert "tx_timeout" not in extra
else:
assert extra["tx_timeout"] == res
57 changes: 57 additions & 0 deletions tests/unit/io/test_class_bolt4x1.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,3 +210,60 @@ def test_hint_recv_timeout_seconds_gets_ignored(fake_socket_pair, recv_timeout):
PoolConfig.max_connection_lifetime)
connection.hello()
sockets.client.settimeout.assert_not_called()


@pytest.mark.parametrize(
("func", "args", "extra_idx"),
(
("run", ("RETURN 1",), 2),
("begin", (), 0),
)
)
@pytest.mark.parametrize(
("timeout", "res"),
(
(None, None),
(0, 0),
(0.1, 100),
(0.001, 1),
(1e-15, 1),
(0.0005, 1),
(0.0001, 1),
(1.0015, 1002),
(1.000499, 1000),
(1.0025, 1002),
(3.0005, 3000),
(3.456, 3456),
(1, 1000),
(
-1e-15,
ValueError("Timeout must be a positive number or 0")
),
(
"foo",
ValueError("Timeout must be specified as a number of seconds")
),
(
[1, 2],
TypeError("Timeout must be specified as a number of seconds")
)
)
)
def test_tx_timeout(fake_socket_pair, func, args, extra_idx, timeout, res):
address = ("127.0.0.1", 7687)
sockets = fake_socket_pair(address)
sockets.server.send_message(0x70, {})
connection = Bolt4x1(address, sockets.client, 0)
func = getattr(connection, func)
if isinstance(res, Exception):
with pytest.raises(type(res), match=str(res)):
func(*args, timeout=timeout)
else:
func(*args, timeout=timeout)
connection.send_all()
tag, fields = sockets.server.pop_message()
extra = fields[extra_idx]
if timeout is None:
assert "tx_timeout" not in extra
else:
assert extra["tx_timeout"] == res
57 changes: 57 additions & 0 deletions tests/unit/io/test_class_bolt4x2.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,60 @@ def test_hint_recv_timeout_seconds_gets_ignored(fake_socket_pair, recv_timeout):
PoolConfig.max_connection_lifetime)
connection.hello()
sockets.client.settimeout.assert_not_called()


@pytest.mark.parametrize(
("func", "args", "extra_idx"),
(
("run", ("RETURN 1",), 2),
("begin", (), 0),
)
)
@pytest.mark.parametrize(
("timeout", "res"),
(
(None, None),
(0, 0),
(0.1, 100),
(0.001, 1),
(1e-15, 1),
(0.0005, 1),
(0.0001, 1),
(1.0015, 1002),
(1.000499, 1000),
(1.0025, 1002),
(3.0005, 3000),
(3.456, 3456),
(1, 1000),
(
-1e-15,
ValueError("Timeout must be a positive number or 0")
),
(
"foo",
ValueError("Timeout must be specified as a number of seconds")
),
(
[1, 2],
TypeError("Timeout must be specified as a number of seconds")
)
)
)
def test_tx_timeout(fake_socket_pair, func, args, extra_idx, timeout, res):
address = ("127.0.0.1", 7687)
sockets = fake_socket_pair(address)
sockets.server.send_message(0x70, {})
connection = Bolt4x2(address, sockets.client, 0)
func = getattr(connection, func)
if isinstance(res, Exception):
with pytest.raises(type(res), match=str(res)):
func(*args, timeout=timeout)
else:
func(*args, timeout=timeout)
connection.send_all()
tag, fields = sockets.server.pop_message()
extra = fields[extra_idx]
if timeout is None:
assert "tx_timeout" not in extra
else:
assert extra["tx_timeout"] == res
Loading