From 2888e86a40885bb2024741a59064d4c356f946e9 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Thu, 5 Oct 2023 14:43:06 +0200 Subject: [PATCH] Adjust tx lifetime tests for Python Currently, the tests assert that the driver throws the same error with which the other result failed. I personally don't like it because that error has nothing to do with the result stream being used (but only the other failed stream). So instead I decided to throw some generic `ResultFailedError` in Python with the other error as context information for easier debugging. However, all of this is fool proofing of the drivers. Users will never encounter this error unless they're abusing the driver or doing really funky stuff with it. So I don't have a strong opinion about what exact error we should throw. The main focus should be on preventing the driver from sending invalid messages to the server and ideally throwing any error that's not absolutely cryptic. --- tests/stub/tx_run/test_tx_run.py | 50 ++++++++++++++++++++++---------- 1 file changed, 35 insertions(+), 15 deletions(-) diff --git a/tests/stub/tx_run/test_tx_run.py b/tests/stub/tx_run/test_tx_run.py index c5e88fa6..bfd2b8b0 100644 --- a/tests/stub/tx_run/test_tx_run.py +++ b/tests/stub/tx_run/test_tx_run.py @@ -1,3 +1,5 @@ +import re + from nutkit import protocol as types from nutkit.frontend import Driver from tests.shared import ( @@ -276,8 +278,9 @@ def _test(): self._assert_is_client_exception(exc) # there must be no further PULL and an exception must be raised + original_exception = exc with self.assertRaises(types.DriverError) as exc: - if iterate == "true": + if iterate: for _i in range(0, 3): res.next() else: @@ -288,17 +291,14 @@ def _test(): # only explicit iteration is tested if fetch all is # not supported list(res) - # the streaming result surfaces the termination exception - self.assertEqual(exc.exception.code, - "Neo.ClientError.Statement.SyntaxError") - self._assert_is_client_exception(exc) + self._assert_is_failed_result_exception(exc, original_exception) tx.close() self._session.close() self._session = None self._server1.done() - for iterate in ["true", "false"]: + for iterate in (True, False): with self.subTest(iterate=iterate): _test() self._server1.reset() @@ -318,12 +318,10 @@ def test_should_prevent_discard_after_tx_termination_on_run(self): "Neo.ClientError.Statement.SyntaxError") self._assert_is_client_exception(exc) + original_exception = exc with self.assertRaises(types.DriverError) as exc: res.consume() - # the streaming result surfaces the termination exception - self.assertEqual(exc.exception.code, - "Neo.ClientError.Statement.SyntaxError") - self._assert_is_client_exception(exc) + self._assert_is_failed_result_exception(exc, original_exception) tx.close() self._session.close() @@ -343,7 +341,7 @@ def test_should_prevent_run_after_tx_termination_on_run(self): self._assert_is_client_exception(exc) with self.assertRaises(types.DriverError) as exc: - tx.run("invalid") + tx.run("RETURN 1 AS n") # new actions on the transaction result in a tx terminated # exception, a subclass of the client exception self._assert_is_tx_terminated_exception(exc) @@ -509,10 +507,16 @@ def _assert_is_client_exception(self, e): e.exception.errorType ) elif driver in ["python"]: - self.assertEqual( - "", - e.exception.errorType - ) + if e.exception.code.endswith(".SyntaxError"): + self.assertEqual( + "", + e.exception.errorType + ) + else: + self.assertEqual( + "", + e.exception.errorType + ) elif driver in ["go"]: self.assertEqual("Neo4jError", e.exception.errorType) self.assertIn("Neo.ClientError.", e.exception.msg) @@ -546,3 +550,19 @@ def _assert_is_tx_terminated_exception(self, e): ) else: self.fail("no error mapping is defined for %s driver" % driver) + + def _assert_is_failed_result_exception(self, e, original_exception): + driver = get_driver_name() + if driver in ["python"]: + self.assertEqual( + e.exception.errorType, + "" + ) + match = re.match(r"", + original_exception.exception.errorType) + self.assertIn(match.group(1), e.exception.msg) + + else: + self.assertEqual(e.exception.code, + original_exception.exception.code) + self._assert_is_client_exception(original_exception)