diff --git a/tests/stub/tx_run/test_tx_run.py b/tests/stub/tx_run/test_tx_run.py index c5e88fa6..bfd2b8b0 100644 --- a/tests/stub/tx_run/test_tx_run.py +++ b/tests/stub/tx_run/test_tx_run.py @@ -1,3 +1,5 @@ +import re + from nutkit import protocol as types from nutkit.frontend import Driver from tests.shared import ( @@ -276,8 +278,9 @@ def _test(): self._assert_is_client_exception(exc) # there must be no further PULL and an exception must be raised + original_exception = exc with self.assertRaises(types.DriverError) as exc: - if iterate == "true": + if iterate: for _i in range(0, 3): res.next() else: @@ -288,17 +291,14 @@ def _test(): # only explicit iteration is tested if fetch all is # not supported list(res) - # the streaming result surfaces the termination exception - self.assertEqual(exc.exception.code, - "Neo.ClientError.Statement.SyntaxError") - self._assert_is_client_exception(exc) + self._assert_is_failed_result_exception(exc, original_exception) tx.close() self._session.close() self._session = None self._server1.done() - for iterate in ["true", "false"]: + for iterate in (True, False): with self.subTest(iterate=iterate): _test() self._server1.reset() @@ -318,12 +318,10 @@ def test_should_prevent_discard_after_tx_termination_on_run(self): "Neo.ClientError.Statement.SyntaxError") self._assert_is_client_exception(exc) + original_exception = exc with self.assertRaises(types.DriverError) as exc: res.consume() - # the streaming result surfaces the termination exception - self.assertEqual(exc.exception.code, - "Neo.ClientError.Statement.SyntaxError") - self._assert_is_client_exception(exc) + self._assert_is_failed_result_exception(exc, original_exception) tx.close() self._session.close() @@ -343,7 +341,7 @@ def test_should_prevent_run_after_tx_termination_on_run(self): self._assert_is_client_exception(exc) with self.assertRaises(types.DriverError) as exc: - tx.run("invalid") + tx.run("RETURN 1 AS n") # new actions on the transaction result in a tx terminated # exception, a subclass of the client exception self._assert_is_tx_terminated_exception(exc) @@ -509,10 +507,16 @@ def _assert_is_client_exception(self, e): e.exception.errorType ) elif driver in ["python"]: - self.assertEqual( - "", - e.exception.errorType - ) + if e.exception.code.endswith(".SyntaxError"): + self.assertEqual( + "", + e.exception.errorType + ) + else: + self.assertEqual( + "", + e.exception.errorType + ) elif driver in ["go"]: self.assertEqual("Neo4jError", e.exception.errorType) self.assertIn("Neo.ClientError.", e.exception.msg) @@ -546,3 +550,19 @@ def _assert_is_tx_terminated_exception(self, e): ) else: self.fail("no error mapping is defined for %s driver" % driver) + + def _assert_is_failed_result_exception(self, e, original_exception): + driver = get_driver_name() + if driver in ["python"]: + self.assertEqual( + e.exception.errorType, + "" + ) + match = re.match(r"", + original_exception.exception.errorType) + self.assertIn(match.group(1), e.exception.msg) + + else: + self.assertEqual(e.exception.code, + original_exception.exception.code) + self._assert_is_client_exception(original_exception)