From 6513f37146d40da6b532ed4e6ffd0712509e72f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A1bio=20C=2E=20Barrionuevo=20da=20Luz?= Date: Thu, 25 Apr 2019 18:09:41 -0300 Subject: [PATCH] Add support for serialized rollback --- .gitignore | 1 + docs/changelog.rst | 11 +++++++++++ docs/helpers.rst | 28 +++++++++++++++++++++++++++- pytest_django/fixtures.py | 26 ++++++++++++++++++++++++-- pytest_django/plugin.py | 21 +++++++++++++-------- tests/test_database.py | 15 ++++++++++++++- 6 files changed, 90 insertions(+), 12 deletions(-) diff --git a/.gitignore b/.gitignore index 90131acf9..5279d25c3 100644 --- a/.gitignore +++ b/.gitignore @@ -16,3 +16,4 @@ _build .Python .eggs *.egg +.idea/ diff --git a/docs/changelog.rst b/docs/changelog.rst index d8c2350cf..d5b32235c 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -1,6 +1,17 @@ Changelog ========= +NEXT +---- + +Features +^^^^^^^^ +* Add support for serialized rollback in transactional tests. (#721) + Thanks to Piotr Karkut for `the bug report + `_. + + + 3.6.0 (2019-10-17) ------------------ diff --git a/docs/helpers.rst b/docs/helpers.rst index 1685b70d0..c0b4f596a 100644 --- a/docs/helpers.rst +++ b/docs/helpers.rst @@ -16,7 +16,7 @@ on what marks are and for notes on using_ them. ``pytest.mark.django_db`` - request database access ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. :py:function:: pytest.mark.django_db([transaction=False, reset_sequences=False]): +.. :py:function:: pytest.mark.django_db([transaction=False, reset_sequences=False, serialized_rollback=False]): This is used to mark a test function as requiring the database. It will ensure the database is set up correctly for the test. Each test @@ -47,6 +47,14 @@ test will fail when trying to access the database. effect. Please be aware that not all databases support this feature. For details see :py:attr:`django.test.TransactionTestCase.reset_sequences`. +:type serialized_rollback: bool +:param serialized_rollback: + The ``serialized_rollback`` argument enables `rollback emulation`_. + After a `django.test.TransactionTestCase`_ runs, the database is + flushed, destroying data created in data migrations. This is the + default behavior of Django. Setting ``serialized_rollback=True`` + tells Django to restore that data. + .. note:: If you want access to the Django database *inside a fixture* @@ -63,6 +71,7 @@ test will fail when trying to access the database. Test classes that subclass Python's ``unittest.TestCase`` need to have the marker applied in order to access the database. +.. _rollback emulation: https://docs.djangoproject.com/en/stable/topics/testing/overview/#rollback-emulation .. _django.test.TestCase: https://docs.djangoproject.com/en/dev/topics/testing/overview/#testcase .. _django.test.TransactionTestCase: https://docs.djangoproject.com/en/dev/topics/testing/overview/#transactiontestcase @@ -242,6 +251,17 @@ sequences (if your database supports it). This is only required for fixtures which need database access themselves. A test function should normally use the ``pytest.mark.django_db`` mark with ``transaction=True`` and ``reset_sequences=True``. +``django_db_serialized_rollback`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +When the ``transactional_db`` fixture is enabled, this fixture can be +added to trigger `rollback emulation`_ and thus restores data created +in data migrations after each transaction test. This is only required +for fixtures which need to enforce this behavior. A test function +would use ``pytest.mark.django_db(serialized_rollback=True)`` +to request this behavior. + + ``live_server`` ~~~~~~~~~~~~~~~ @@ -251,6 +271,12 @@ or by requesting it's string value: ``unicode(live_server)``. You can also directly concatenate a string to form a URL: ``live_server + '/foo``. +Since the live server and the tests run in different threads, they +cannot share a database transaction. For this reason, ``live_server`` +depends on the ``transactional_db`` fixture. If tests depend on data +created in data migrations, you should add the ``serialized_rollback`` +fixture. + .. note:: Combining database access fixtures. When using multiple database fixtures together, only one of them is diff --git a/pytest_django/fixtures.py b/pytest_django/fixtures.py index 519522c80..f0a7a2a5f 100644 --- a/pytest_django/fixtures.py +++ b/pytest_django/fixtures.py @@ -18,6 +18,7 @@ "db", "transactional_db", "django_db_reset_sequences", + "django_db_serialized_rollback", "admin_user", "django_user_model", "django_username_field", @@ -124,7 +125,8 @@ def teardown_database(): def _django_db_fixture_helper( - request, django_db_blocker, transactional=False, reset_sequences=False + request, django_db_blocker, transactional=False, reset_sequences=False, + serialized_rollback=False ): if is_django_unittest(request): return @@ -149,6 +151,7 @@ class ResetSequenceTestCase(django_case): from django.test import TestCase as django_case test_case = django_case(methodName="__init__") + test_case.serialized_rollback = serialized_rollback test_case._pre_setup() request.addfinalizer(test_case._post_teardown) @@ -207,13 +210,16 @@ def db(request, django_db_setup, django_db_blocker): """ if "django_db_reset_sequences" in request.fixturenames: request.getfixturevalue("django_db_reset_sequences") + if "django_db_serialized_rollback" in request.fixturenames: + request.getfixturevalue("django_db_serialized_rollback") if ( "transactional_db" in request.fixturenames or "live_server" in request.fixturenames ): request.getfixturevalue("transactional_db") else: - _django_db_fixture_helper(request, django_db_blocker, transactional=False) + _django_db_fixture_helper(request, django_db_blocker, transactional=False, + serialized_rollback=False) @pytest.fixture(scope="function") @@ -232,6 +238,8 @@ def transactional_db(request, django_db_setup, django_db_blocker): """ if "django_db_reset_sequences" in request.fixturenames: request.getfixturevalue("django_db_reset_sequences") + if "django_db_serialized_rollback" in request.fixturenames: + request.getfixturevalue("django_db_serialized_rollback") _django_db_fixture_helper(request, django_db_blocker, transactional=True) @@ -253,6 +261,20 @@ def django_db_reset_sequences(request, django_db_setup, django_db_blocker): ) +@pytest.fixture(scope="function") +def django_db_serialized_rollback(request, django_db_setup, django_db_blocker): + """Enable serialized rollback after transaction test cases + + This fixture only has an effect when the ``transactional_db`` + fixture is active, which happen as a side-effect of requesting + ``live_server``. + + """ + _django_db_fixture_helper( + request, django_db_blocker, transactional=True, serialized_rollback=True + ) + + @pytest.fixture() def client(): """A Django test client instance.""" diff --git a/pytest_django/plugin.py b/pytest_django/plugin.py index 898257cb5..ebfed3b27 100644 --- a/pytest_django/plugin.py +++ b/pytest_django/plugin.py @@ -34,6 +34,7 @@ from .fixtures import django_username_field # noqa from .fixtures import live_server # noqa from .fixtures import django_db_reset_sequences # noqa +from .fixtures import django_db_serialized_rollback # noqa from .fixtures import rf # noqa from .fixtures import settings # noqa from .fixtures import transactional_db # noqa @@ -497,14 +498,17 @@ def django_db_blocker(): def _django_db_marker(request): """Implement the django_db marker, internal to pytest-django. - This will dynamically request the ``db``, ``transactional_db`` or - ``django_db_reset_sequences`` fixtures as required by the django_db marker. + This will dynamically request the ``db``, ``transactional_db``, + ``django_db_reset_sequences`` or ``django_db_serialized_rollback`` + fixtures as required by the django_db marker. """ marker = request.node.get_closest_marker("django_db") if marker: - transaction, reset_sequences = validate_django_db(marker) + transaction, reset_sequences, serialized_rollback = validate_django_db(marker) if reset_sequences: request.getfixturevalue("django_db_reset_sequences") + elif serialized_rollback: + request.getfixturevalue("django_db_serialized_rollback") elif transaction: request.getfixturevalue("transactional_db") else: @@ -805,15 +809,16 @@ def restore(self): def validate_django_db(marker): """Validate the django_db marker. - It checks the signature and creates the ``transaction`` and - ``reset_sequences`` attributes on the marker which will have the - correct values. + It checks the signature and creates the ``transaction``, + ``reset_sequences`` and ``serialized_rollback`` attributes on + the marker which will have the correct values. A sequence reset is only allowed when combined with a transaction. + A serialized rollback is only allowed when combined with a transaction. """ - def apifun(transaction=False, reset_sequences=False): - return transaction, reset_sequences + def apifun(transaction=False, reset_sequences=False, serialized_rollback=False): + return transaction, reset_sequences, serialized_rollback return apifun(*marker.args, **marker.kwargs) diff --git a/tests/test_database.py b/tests/test_database.py index 7bcb06289..c9ca19ff5 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -50,10 +50,13 @@ def non_zero_sequences_counter(db): class TestDatabaseFixtures: """Tests for the different database fixtures.""" - @pytest.fixture(params=["db", "transactional_db", "django_db_reset_sequences"]) + @pytest.fixture(params=["db", "transactional_db", "django_db_reset_sequences", + "django_db_serialized_rollback"]) def all_dbs(self, request): if request.param == "django_db_reset_sequences": return request.getfixturevalue("django_db_reset_sequences") + elif request.param == "django_db_serialized_rollback": + return request.getfixturevalue("django_db_serialized_rollback") elif request.param == "transactional_db": return request.getfixturevalue("transactional_db") elif request.param == "db": @@ -215,6 +218,16 @@ def test_reset_sequences_enabled(self, request): marker = request.node.get_closest_marker("django_db") assert marker.kwargs["reset_sequences"] + @pytest.mark.django_db + def test_serialized_rollback_disabled(self, request): + marker = request.node.get_closest_marker("django_db") + assert not marker.kwargs + + @pytest.mark.django_db(serialized_rollback=True) + def test_serialized_rollback_enabled(self, request): + marker = request.node.get_closest_marker("django_db") + assert marker.kwargs["serialized_rollback"] + def test_unittest_interaction(django_testdir): "Test that (non-Django) unittests cannot access the DB."