diff --git a/nutkit/frontend/__init__.py b/nutkit/frontend/__init__.py index 15d2f51a1..b267f5f7a 100644 --- a/nutkit/frontend/__init__.py +++ b/nutkit/frontend/__init__.py @@ -1 +1,2 @@ from .driver import Driver +from .exceptions import ApplicationCodeError diff --git a/nutkit/frontend/exceptions.py b/nutkit/frontend/exceptions.py new file mode 100644 index 000000000..cf44a07db --- /dev/null +++ b/nutkit/frontend/exceptions.py @@ -0,0 +1,2 @@ +class ApplicationCodeError(Exception): + pass diff --git a/nutkit/frontend/session.py b/nutkit/frontend/session.py index e13fc4084..48fd4a86b 100644 --- a/nutkit/frontend/session.py +++ b/nutkit/frontend/session.py @@ -1,12 +1,9 @@ from .. import protocol +from .exceptions import ApplicationCodeError from .result import Result from .transaction import Transaction -class ApplicationCodeError(Exception): - pass - - class Session: def __init__(self, driver, session): self._driver = driver diff --git a/tests/neo4j/test_tx_func_run.py b/tests/neo4j/test_tx_func_run.py index bfc11eb74..9da5035c8 100644 --- a/tests/neo4j/test_tx_func_run.py +++ b/tests/neo4j/test_tx_func_run.py @@ -1,4 +1,4 @@ -from nutkit.frontend.session import ApplicationCodeError +from nutkit.frontend import ApplicationCodeError import nutkit.protocol as types from tests.neo4j.shared import ( get_driver, @@ -115,9 +115,6 @@ def run(tx): self.assertGreater(len(bookmarks[0]), 3) def test_does_not_update_last_bookmark_on_rollback(self): - if get_driver_name() in ["java"]: - self.skipTest("Client exceptions not properly handled in backend") - # Verifies that last bookmarks still is empty when transactional # function rolls back transaction. def run(tx): @@ -125,20 +122,12 @@ def run(tx): raise ApplicationCodeError("No thanks") self._session1 = self._driver.session("w") - expected_exc = types.FrontendError - # TODO: remove this block once all languages work - if get_driver_name() in ["javascript"]: - expected_exc = types.DriverError - if get_driver_name() in ["dotnet"]: - expected_exc = types.BackendError - with self.assertRaises(expected_exc): + with self.assertRaises(types.FrontendError): self._session1.write_transaction(run) bookmarks = self._session1.last_bookmarks() self.assertEqual(len(bookmarks), 0) def test_client_exception_rolls_back_change(self): - if get_driver_name() in ["java"]: - self.skipTest("Client exceptions not properly handled in backend") node_id = -1 def run(tx): @@ -158,11 +147,6 @@ def assertion_query(tx): self._session1 = self._driver.session("w") expected_exc = types.FrontendError - # TODO: remove this block once all languages work - if get_driver_name() in ["javascript"]: - expected_exc = types.DriverError - if get_driver_name() in ["dotnet"]: - expected_exc = types.BackendError with self.assertRaises(expected_exc): self._session1.write_transaction(run) diff --git a/tests/stub/tx_lifetime/__init__.py b/tests/stub/tx_lifetime/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/stub/tx_lifetime/scripts/v4x4/tx_inf_results_until_end.script b/tests/stub/tx_lifetime/scripts/v4x4/tx_inf_results_until_end.script new file mode 100644 index 000000000..ed941415b --- /dev/null +++ b/tests/stub/tx_lifetime/scripts/v4x4/tx_inf_results_until_end.script @@ -0,0 +1,32 @@ +!: BOLT 4.4 + +A: HELLO {"{}": "*"} +*: RESET +C: BEGIN {"{}": "*"} +S: SUCCESS {} +C: RUN {"U": "*"} {"{}": "*"} {"{}": "*"} +S: SUCCESS {"fields": ["n"]} +{* + C: PULL {"n": {"Z": "*"}, "[qid]": -1} + S: RECORD [1] + RECORD [2] + SUCCESS {"has_more": true} +*} +{{ + C: DISCARD {"n": -1, "[qid]": -1} + S: SUCCESS {"type": "r"} + {? + {{ + C: ROLLBACK + ---- + C: COMMIT + ---- + C: RESET + }} + S: SUCCESS {} + ?} +---- + A: RESET +}} +*: RESET +?: GOODBYE diff --git a/tests/stub/tx_lifetime/test_tx_lifetime.py b/tests/stub/tx_lifetime/test_tx_lifetime.py new file mode 100644 index 000000000..be2a332c4 --- /dev/null +++ b/tests/stub/tx_lifetime/test_tx_lifetime.py @@ -0,0 +1,122 @@ +from contextlib import contextmanager + +from nutkit.frontend import Driver +import nutkit.protocol as types +from tests.shared import ( + get_driver_name, + TestkitTestCase, +) +from tests.stub.shared import StubServer + + +class TestTxLifetime(TestkitTestCase): + def setUp(self): + super().setUp() + self._server = StubServer(9000) + + def tearDown(self): + # If test raised an exception this will make sure that the stub server + # is killed and it's output is dumped for analysis. + self._server.reset() + super().tearDown() + + @contextmanager + def _start_session(self, script): + uri = "bolt://%s" % self._server.address + driver = Driver(self._backend, uri, + types.AuthorizationToken("basic", principal="", + credentials="")) + self._server.start(path=self.script_path("v4x4", script)) + session = driver.session("r", fetch_size=2) + try: + yield session + finally: + session.close() + driver.close() + + def _asserts_tx_closed_error(self, exc): + driver = get_driver_name() + assert isinstance(exc, types.DriverError) + if driver in ["python"]: + self.assertEqual(exc.errorType, + "") + self.assertIn("closed", exc.msg.lower()) + elif driver in ["javascript", "go", "dotnet"]: + self.assertIn("transaction", exc.msg.lower()) + elif driver in ["java"]: + self.assertEqual(exc.errorType, + "org.neo4j.driver.exceptions.ClientException") + else: + self.fail("no error mapping is defined for %s driver" % driver) + + def _asserts_tx_managed_error(self, exc): + driver = get_driver_name() + if driver in ["python"]: + self.assertEqual(exc.errorType, "") + self.assertIn("managed", exc.msg.lower()) + elif driver in ["go"]: + self.assertIn("retryable transaction", exc.msg.lower()) + else: + self.fail("no error mapping is defined for %s driver" % driver) + + def _test_unmanaged_tx(self, first_action, second_action): + exc = None + script = "tx_inf_results_until_end.script" + with self._start_session(script) as session: + tx = session.begin_transaction() + res = tx.run("Query") + res.consume() + getattr(tx, first_action)() + if second_action == "close": + getattr(tx, second_action)() + elif second_action == "run": + with self.assertRaises(types.DriverError) as exc: + tx.run("Query").consume() + else: + with self.assertRaises(types.DriverError) as exc: + getattr(tx, second_action)() + + self._server.done() + self.assertEqual( + self._server.count_requests("ROLLBACK"), + int(first_action in ["rollback", "close"]) + ) + self.assertEqual( + self._server.count_requests("COMMIT"), + int(first_action == "commit") + ) + if exc is not None: + self._asserts_tx_closed_error(exc.exception) + + def test_unmanaged_tx_raises_tx_closed_exec(self): + for first_action in ("commit", "rollback", "close"): + for second_action in ("commit", "rollback", "close", "run"): + with self.subTest(first_action=first_action, + second_action=second_action): + self._test_unmanaged_tx(first_action, second_action) + self._server.reset() + + def _test_managed_tx(self, close_action): + def work(tx_): + res_ = tx_.run("Query") + res_.consume() + with self.assertRaises(types.DriverError) as exc_: + getattr(tx_, close_action)() + self._asserts_tx_managed_error(exc_.exception) + raise exc_.exception + + script = "tx_inf_results_until_end.script" + with self._start_session(script) as session: + with self.assertRaises(types.DriverError): + session.read_transaction(work) + + self._server.done() + self._server._dump() + self.assertEqual(self._server.count_requests("ROLLBACK"), 1) + self.assertEqual(self._server.count_requests("COMMIT"), 0) + + def test_managed_tx_raises_tx_managed_exec(self): + for close_action in ("commit", "rollback", "close"): + with self.subTest(close_action=close_action): + self._test_managed_tx(close_action) + self._server.reset()