Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
3a8b844
Change scripts for home db caching
MaxAake Oct 17, 2024
4646f4e
Update reader_change_homedb.script
MaxAake Oct 22, 2024
513ac2b
update to tests for table cache
MaxAake Oct 24, 2024
f3a184d
test homedb cache
MaxAake Oct 29, 2024
d3662be
connection hint in homedb test
MaxAake Nov 1, 2024
b9a3fd1
Upgrade to bolt 5.7, move db to begin success
MaxAake Nov 1, 2024
da34b9e
WIP
robsdedude Nov 4, 2024
67bbd46
Merge branch '5.0' into homedb-cache-spike
robsdedude Nov 4, 2024
fd01942
Revert test-cases that should be unaffected
robsdedude Nov 4, 2024
f6ea08a
Refactor home db tests in prep for more home db cache tests
robsdedude Nov 4, 2024
db760ed
Add bolt 5.8 support to stub server
robsdedude Nov 5, 2024
a511f62
Improve auto response logging in bolt stub server
robsdedude Nov 5, 2024
5384082
Test direct driver doesn't pin db
robsdedude Nov 5, 2024
99bb485
fixed tests for older bolt
MaxAake Nov 7, 2024
ebb0bc0
Merge remote-tracking branch 'origin/homedb-cache-spike' into homedb-…
robsdedude Nov 8, 2024
bfc376a
Test home db cache is used for routing only
robsdedude Nov 8, 2024
7d8d62a
fix spacing for list alternatives tests
MaxAake Nov 11, 2024
9db1c2f
HomeDB cache routing: test different cache keys
robsdedude Nov 11, 2024
36db6d0
Add tests for cache key precedence
robsdedude Nov 11, 2024
71135f8
Test driver pins db even if server re-resolves it
robsdedude Nov 11, 2024
2aadd47
Rename test base class to avoid it being picked up as a test
robsdedude Nov 11, 2024
2282be6
Fix shared tests in base class being collected
robsdedude Nov 11, 2024
95eac84
Add test for RT changing underneath the cache
robsdedude Nov 11, 2024
7b1e814
Add feature flag and tests for basic auth principal optimization
robsdedude Nov 12, 2024
cf91370
allow concurrent for whoami test
MaxAake Nov 14, 2024
220fb8a
Minor clean-ups
RichardIrons-neo4j Dec 4, 2024
ee412b2
Code clean-up
robsdedude Dec 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 33 additions & 24 deletions boltstub/bolt_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def decode_versions(cls, b):

@classmethod
def get_auto_response(cls, request: TranslatedStructure):
if request.tag == b"\x01":
if request.name == "HELLO":
return TranslatedStructure(
"SUCCESS", b"\x70", {"server": cls.server_agent},
packstream_version=cls.packstream_version
Expand All @@ -246,7 +246,7 @@ class Bolt2Protocol(Bolt1Protocol):

@classmethod
def get_auto_response(cls, request: TranslatedStructure):
if request.tag == b"\x01":
if request.name == "HELLO":
return TranslatedStructure(
"SUCCESS", b"\x70", {"server": cls.server_agent},
packstream_version=cls.packstream_version
Expand Down Expand Up @@ -291,7 +291,7 @@ class Bolt3Protocol(Bolt2Protocol):

@classmethod
def get_auto_response(cls, request: TranslatedStructure):
if request.tag == b"\x01":
if request.name == "HELLO":
return TranslatedStructure(
"SUCCESS", b"\x70",
{
Expand Down Expand Up @@ -348,12 +348,12 @@ def decode_versions(cls, b):

@classmethod
def get_auto_response(cls, request: TranslatedStructure):
if request.tag == b"\x01":
if request.name == "HELLO":
return TranslatedStructure(
"SUCCESS", b"\x70",
{
"connection_id": next_auto_bolt_id(),
"server": cls.server_agent
"server": cls.server_agent,
},
packstream_version=cls.packstream_version
)
Expand Down Expand Up @@ -395,24 +395,6 @@ class Bolt4x1Protocol(Bolt4x0Protocol):

server_agent = "Neo4j/4.1.0"

@classmethod
def get_auto_response(cls, request: TranslatedStructure):
if request.tag == b"\x01":
return TranslatedStructure(
"SUCCESS", b"\x70",
{
"connection_id": next_auto_bolt_id(),
"server": cls.server_agent,
"routing": None,
},
packstream_version=cls.packstream_version
)
else:
return TranslatedStructure(
"SUCCESS", b"\x70", {},
packstream_version=cls.packstream_version
)


class Bolt4x2Protocol(Bolt4x1Protocol):

Expand Down Expand Up @@ -569,4 +551,31 @@ class Bolt5x7Protocol(Bolt5x6Protocol):
# allow the server to negotiate other bolt versions
equivalent_versions = set()

server_agent = "Neo4j/5.24.0"
server_agent = "Neo4j/5.26.0"


class Bolt5x8Protocol(Bolt5x7Protocol):
protocol_version = (5, 8)
version_aliases = set()
# allow the server to negotiate other bolt versions
equivalent_versions = set()

server_agent = "Neo4j/5.26.0"

@classmethod
def get_auto_response(cls, request: TranslatedStructure):
if request.name == "HELLO":
return TranslatedStructure(
"SUCCESS", b"\x70",
{
"connection_id": next_auto_bolt_id(),
"server": cls.server_agent,
"hints": {"ssr.enabled": True}
},
packstream_version=cls.packstream_version
)
else:
return TranslatedStructure(
"SUCCESS", b"\x70", {},
packstream_version=cls.packstream_version
)
15 changes: 6 additions & 9 deletions boltstub/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,6 @@ def send_raw(self, b):
self.wire.write(b)
self.wire.send()

def send_struct(self, struct):
self.log("S: %s", struct)
self.stream.write_message(struct)
self.stream.drain()

def send_server_line(self, server_line):
self.log("%s", server_line)
server_line = self.bolt_protocol.translate_server_line(server_line)
Expand All @@ -128,9 +123,9 @@ def _consume(self):
def consume(self, line_no=None):
if self._buffered_msg is not None:
if line_no is not None:
self.log("(%3i) C: %s", line_no, self._buffered_msg)
self.log("(%4i) C: %s", line_no, self._buffered_msg)
else:
self.log("(%3i) C: %s", self._buffered_msg)
self.log("(%4i) C: %s", self._buffered_msg)
msg = self._buffered_msg
self._buffered_msg = None
return msg
Expand Down Expand Up @@ -158,8 +153,10 @@ def assert_no_input(self):
)

def auto_respond(self, msg):
self.log("AUTO response:")
self.send_struct(self.bolt_protocol.get_auto_response(msg))
struct = self.bolt_protocol.get_auto_response(msg)
self.log("(AUTO) S: %s", struct)
self.stream.write_message(struct)
self.stream.drain()

def try_auto_consume(self, whitelist: Iterable[str]):
next_msg = self.peek()
Expand Down
73 changes: 47 additions & 26 deletions boltstub/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def __new__(cls, line_number: int, raw_line, content: str):
return obj

def __str__(self):
return "({:3}) {}".format(self.line_number,
return "({:4}) {}".format(self.line_number,
super(Line, self).__str__())

def __repr__(self):
Expand Down Expand Up @@ -514,7 +514,7 @@ def consume(self, channel):
assert self.try_consume(channel)

@abc.abstractmethod
def has_deterministic_end(self):
def has_deterministic_end(self, channel=None) -> bool:
pass

@abc.abstractmethod
Expand Down Expand Up @@ -581,7 +581,7 @@ def _consume(self, channel):
def done(self, channel):
return self.index >= len(self.lines)

def has_deterministic_end(self) -> bool:
def has_deterministic_end(self, channel=None) -> bool:
return True

def init(self, channel):
Expand Down Expand Up @@ -659,7 +659,7 @@ def can_consume_after_reset(self, channel) -> bool:
def done(self, channel):
return self.index >= len(self.lines)

def has_deterministic_end(self) -> bool:
def has_deterministic_end(self, channel=None) -> bool:
return True

def init(self, channel):
Expand Down Expand Up @@ -762,8 +762,8 @@ def done(self, channel):
return (self.selection is not None
and self.block_lists[self.selection].done(channel))

def has_deterministic_end(self) -> bool:
return all(b.has_deterministic_end() for b in self.block_lists)
def has_deterministic_end(self, channel=None) -> bool:
return all(b.has_deterministic_end(channel) for b in self.block_lists)

def init(self, channel):
# self.assert_no_init()
Expand Down Expand Up @@ -834,8 +834,8 @@ def can_consume_after_reset(self, channel) -> bool:
return any(b.can_consume_after_reset(channel)
for b in self.block_lists)

def has_deterministic_end(self) -> bool:
return all(b.has_deterministic_end() for b in self.block_lists)
def has_deterministic_end(self, channel=None) -> bool:
return all(b.has_deterministic_end(channel) for b in self.block_lists)

def init(self, channel):
# self.assert_no_init()
Expand Down Expand Up @@ -889,7 +889,7 @@ def assert_no_init(self):

def can_be_skipped(self, channel):
if self.started:
if self.block_list.has_deterministic_end():
if self.block_list.has_deterministic_end(channel):
return self.block_list.done(channel)
return self.block_list.can_be_skipped(channel)
return True
Expand All @@ -901,14 +901,14 @@ def can_consume_after_reset(self, channel) -> bool:
return self.block_list.can_consume_after_reset(channel)

def done(self, channel) -> bool:
if self.started and self.block_list.has_deterministic_end():
if self.started and self.block_list.has_deterministic_end(channel):
return self.block_list.done(channel)
raise RuntimeError("it's nondeterministic!")

def has_deterministic_end(self) -> bool:
def has_deterministic_end(self, channel=None) -> bool:
if not self.started:
return False
return self.block_list.has_deterministic_end()
return self.block_list.has_deterministic_end(channel)

def init(self, channel):
# self.assert_no_init()
Expand Down Expand Up @@ -951,8 +951,10 @@ def __init__(self, block_list, line_number: int):
def accepted_messages(self, channel) -> List[ClientLine]:
res = OrderedDict((m, True)
for m in self.block_list.accepted_messages(channel))
if ((self.has_deterministic_end() and self.done(channel))
or self.block_list.can_be_skipped(channel)):
if (
(self.has_deterministic_end(channel) and self.done(channel))
or self.block_list.can_be_skipped(channel)
):
res.update(
(m, True)
for m in self.block_list.accepted_messages_after_reset(channel)
Expand Down Expand Up @@ -984,7 +986,7 @@ def _can_consume_nondeterministic(self, channel):
return False

def can_consume(self, channel) -> bool:
if self.block_list.has_deterministic_end():
if self.block_list.has_deterministic_end(channel):
return self._can_consume_deterministic(channel)
return self._can_consume_nondeterministic(channel)

Expand All @@ -994,7 +996,7 @@ def can_consume_after_reset(self, channel) -> bool:
def done(self, channel) -> bool:
raise RuntimeError("it's nondeterministic!")

def has_deterministic_end(self) -> bool:
def has_deterministic_end(self, channel=None) -> bool:
return False

def init(self, channel):
Expand Down Expand Up @@ -1030,7 +1032,7 @@ def _try_consume_nondeterministic(self, channel):
return False

def try_consume(self, channel) -> bool:
if self.block_list.has_deterministic_end():
if self.block_list.has_deterministic_end(channel):
return self._try_consume_deterministic(channel)
return self._try_consume_nondeterministic(channel)

Expand Down Expand Up @@ -1111,7 +1113,8 @@ def accepted_messages_after_reset(self, channel) -> List[ClientLine]:
return block.accepted_messages_after_reset(channel)

def assert_no_init(self):
return
for block in self.blocks:
block.assert_no_init()

def done(self, channel) -> bool:
block = self._probe_selection(channel, self.selection)
Expand All @@ -1138,8 +1141,17 @@ def can_consume_after_reset(self, channel) -> bool:
return block.can_consume_after_reset(channel)
pass

def has_deterministic_end(self):
return all(b.has_deterministic_end() for b in self.blocks)
def has_deterministic_end(self, channel=None) -> bool:
if channel is None:
if len(self.blocks) <= len(self.conditions):
# no else block => cannot guarantee deterministic end at static
# check time
return False
return all(b.has_deterministic_end() for b in self.blocks)
block = self._probe_selection(channel, self.selection)
if not block:
return True
return block.has_deterministic_end(channel)

def init(self, channel):
block = self._get_selection(channel, self.selection)
Expand Down Expand Up @@ -1233,18 +1245,24 @@ def can_consume_after_reset(self, channel) -> bool:
return False

def done(self, channel) -> bool:
if not self.has_deterministic_end():
if not self.has_deterministic_end(channel):
raise RuntimeError("it's nondeterministic!")
return self.index >= len(self.blocks)

def has_deterministic_end(self) -> bool:
return not self.blocks or self.blocks[-1].has_deterministic_end()
def has_deterministic_end(self, channel=None) -> bool:
return (
not self.blocks
or self.blocks[-1].has_deterministic_end(channel)
)

def init(self, channel):
while self.index < len(self.blocks):
block = self.blocks[self.index]
block.init(channel)
if not block.has_deterministic_end() or not block.done(channel):
if (
not block.has_deterministic_end(channel)
or not block.done(channel)
):
break
self.index += 1

Expand All @@ -1258,7 +1276,10 @@ def try_consume(self, channel) -> bool:
block = self.blocks[i]
if block.try_consume(channel):
self.index = i
while block.has_deterministic_end() and block.done(channel):
while (
block.has_deterministic_end(channel)
and block.done(channel)
):
self.index += 1
if self.index < len(self.blocks):
block = self.blocks[self.index]
Expand Down Expand Up @@ -1392,7 +1413,7 @@ def done(self, channel):
with self._lock:
if self._skipped:
return True
if self.block_list.has_deterministic_end():
if self.block_list.has_deterministic_end(channel):
return self.block_list.done(channel)
return False

Expand Down
10 changes: 5 additions & 5 deletions boltstub/tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,8 +595,8 @@ def test_lists_alternatives_on_unexpected_message(msg, restarting, concurrent,
assert len(server.service.exceptions) == 1
server_exc = server.service.exceptions[0]
line_offset = restarting + concurrent
assert "(%3i) C: RUN" % (5 + line_offset) in str(server_exc)
assert "(%3i) C: RESET" % (7 + line_offset) in str(server_exc)
assert "( %3i) C: RUN" % (5 + line_offset) in str(server_exc)
assert "( %3i) C: RESET" % (7 + line_offset) in str(server_exc)


# @pytest.mark.parametrize("block_marker", ("{?", "{*", "{+", "{{"))
Expand Down Expand Up @@ -625,11 +625,11 @@ def test_lists_alternatives_on_unexpected_message_with_non_det_block(
con.read(6)
assert len(server.service.exceptions) == 1
server_exc = server.service.exceptions[0]
assert "( 5) C: RUN" in str(server_exc)
assert "( 5) C: RUN" in str(server_exc)
if block_marker in ("{?", "{*"):
assert "( 7) C: RESET" in str(server_exc)
assert "( 7) C: RESET" in str(server_exc)
else:
assert "( 7) C: RESET" not in str(server_exc)
assert "( 7) C: RESET" not in str(server_exc)


def test_unknown_message(server_factory, connection_factory):
Expand Down
3 changes: 0 additions & 3 deletions boltstub/tests/test_script_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,6 @@ def match_client_line(self, client_line, msg):
def send_raw(self, b):
self.raw_buffer.extend(b)

def send_struct(self, struct):
self.msg_buffer.append(struct)

def send_server_line(self, server_line):
server_line.parse_jolt(self.jolt_package)
name, fields = server_line.jolt_parsed
Expand Down
6 changes: 6 additions & 0 deletions nutkit/protocol/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ class Feature(Enum):
BOLT_5_6 = "Feature:Bolt:5.6"
# The driver supports Bolt protocol version 5.7
BOLT_5_7 = "Feature:Bolt:5.7"
# The driver supports Bolt protocol version 5.8
BOLT_5_8 = "Feature:Bolt:5.8"
# The driver supports patching DateTimes to use UTC for Bolt 4.3 and 4.4
BOLT_PATCH_UTC = "Feature:Bolt:Patch:UTC"
# The driver supports impersonation
Expand Down Expand Up @@ -167,6 +169,10 @@ class Feature(Enum):
# sending BEGIN but pipelines the RUN and PULL right afterwards and
# consumes three messages after that. This saves 2 full round-trips.
OPT_EXECUTE_QUERY_PIPELINING = "Optimization:ExecuteQueryPipelining"
# The home db cache for optimistic home db resolution treats the principal
# in basic auth the exact same way it treats impersonated users.
OPT_HOME_DB_CACHE_BASIC_PRINCIPAL_IS_IMP_USER = \
"Optimization:HomeDbCacheBasicPrincipalIsImpersonatedUser"
# Driver doesn't explicitly send message data that is the default value.
# This conserves bandwidth.
OPT_IMPLICIT_DEFAULT_ARGUMENTS = "Optimization:ImplicitDefaultArguments"
Expand Down
Loading