diff --git a/docs/changelog.rst b/docs/changelog.rst index 5b04c3ba..868e20ea 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -1,6 +1,16 @@ Changelog ========= +<<<<<<< HEAD +NEXT +---- + +Features +^^^^^^^^ +* Add support for serialized rollback in transactional tests. (#721) + Thanks to Piotr Karkut for `the bug report + `_. +======= v4.4.0 (2021-06-06) ------------------- @@ -40,6 +50,7 @@ Bugfixes ^^^^^^^^ * Disable atomic durability check on non-transactional tests (#910). +>>>>>>> master v4.1.0 (2020-10-22) diff --git a/docs/helpers.rst b/docs/helpers.rst index 774237b3..b3f21ad4 100644 --- a/docs/helpers.rst +++ b/docs/helpers.rst @@ -26,7 +26,11 @@ dynamically in a hook or fixture. ``pytest.mark.django_db`` - request database access ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +<<<<<<< HEAD +.. py:function:: pytest.mark.django_db([transaction=False, reset_sequences=False, serialized_rollback=False]) +======= .. py:function:: pytest.mark.django_db([transaction=False, reset_sequences=False, databases=None]) +>>>>>>> master This is used to mark a test function as requiring the database. It will ensure the database is set up correctly for the test. Each test @@ -56,6 +60,15 @@ dynamically in a hook or fixture. effect. Please be aware that not all databases support this feature. For details see :py:attr:`django.test.TransactionTestCase.reset_sequences`. +<<<<<<< HEAD + :type serialized_rollback: bool + :param serialized_rollback: + The ``serialized_rollback`` argument enables `rollback emulation`_. + After a `django.test.TransactionTestCase`_ runs, the database is + flushed, destroying data created in data migrations. This is the + default behavior of Django. Setting ``serialized_rollback=True`` + tells Django to restore that data. +======= :type databases: Union[Iterable[str], str, None] :param databases: @@ -72,6 +85,7 @@ dynamically in a hook or fixture. to specify all configured databases. For details see :py:attr:`django.test.TransactionTestCase.databases` and :py:attr:`django.test.TestCase.databases`. +>>>>>>> master .. note:: @@ -88,7 +102,11 @@ dynamically in a hook or fixture. Test classes that subclass :class:`django.test.TestCase` will have access to the database always to make them compatible with existing Django tests. Test classes that subclass Python's :class:`unittest.TestCase` need to have - the marker applied in order to access the database. + marker applied in order to access the database. + +.. _rollback emulation: https://docs.djangoproject.com/en/stable/topics/testing/overview/#rollback-emulation +.. _django.test.TestCase: https://docs.djangoproject.com/en/dev/topics/testing/overview/#testcase +.. _django.test.TransactionTestCase: https://docs.djangoproject.com/en/dev/topics/testing/overview/#transactiontestcase ``pytest.mark.urls`` - override the urlconf @@ -333,6 +351,17 @@ use the :func:`pytest.mark.django_db` mark with ``transaction=True`` and .. fixture:: live_server +``django_db_serialized_rollback`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +When the ``transactional_db`` fixture is enabled, this fixture can be +added to trigger `rollback emulation`_ and thus restores data created +in data migrations after each transaction test. This is only required +for fixtures which need to enforce this behavior. A test function +would use ``pytest.mark.django_db(serialized_rollback=True)`` +to request this behavior. + + ``live_server`` ~~~~~~~~~~~~~~~ @@ -342,6 +371,12 @@ or by requesting it's string value: ``str(live_server)``. You can also directly concatenate a string to form a URL: ``live_server + '/foo'``. +Since the live server and the tests run in different threads, they +cannot share a database transaction. For this reason, ``live_server`` +depends on the ``transactional_db`` fixture. If tests depend on data +created in data migrations, you should add the ``serialized_rollback`` +fixture. + .. note:: Combining database access fixtures. When using multiple database fixtures together, only one of them is diff --git a/pytest_django/fixtures.py b/pytest_django/fixtures.py index 7fcfe679..838bb48d 100644 --- a/pytest_django/fixtures.py +++ b/pytest_django/fixtures.py @@ -3,7 +3,14 @@ from contextlib import contextmanager from functools import partial from typing import ( - Any, Callable, Generator, Iterable, List, Optional, Tuple, Union, + Any, + Callable, + Generator, + Iterable, + List, + Optional, + Tuple, + Union, ) import pytest @@ -28,6 +35,7 @@ "db", "transactional_db", "django_db_reset_sequences", + "django_db_serialized_rollback", "admin_user", "django_user_model", "django_username_field", @@ -143,6 +151,7 @@ def _django_db_fixture_helper( django_db_blocker, transactional: bool = False, reset_sequences: bool = False, + serialized_rollback: bool = False, ) -> None: if is_django_unittest(request): return @@ -152,7 +161,9 @@ def _django_db_fixture_helper( return _databases = getattr( - request.node, "_pytest_django_databases", None, + request.node, + "_pytest_django_databases", + None, ) # type: Optional[_DjangoDbDatabases] django_db_blocker.unblock() @@ -163,6 +174,21 @@ def _django_db_fixture_helper( if transactional: test_case_class = django.test.TransactionTestCase + + if reset_sequences: + + class ResetSequenceTestCase(test_case_class): + reset_sequences = True + + test_case_class = ResetSequenceTestCase + + if serialized_rollback: + + class SerializedRollbackTestCase(test_case_class): + serialized_rollback = True + + test_case_class = SerializedRollbackTestCase + else: test_case_class = django.test.TestCase @@ -245,13 +271,17 @@ def db( """ if "django_db_reset_sequences" in request.fixturenames: request.getfixturevalue("django_db_reset_sequences") + if "django_db_serialized_rollback" in request.fixturenames: + request.getfixturevalue("django_db_serialized_rollback") if ( "transactional_db" in request.fixturenames or "live_server" in request.fixturenames ): request.getfixturevalue("transactional_db") else: - _django_db_fixture_helper(request, django_db_blocker, transactional=False) + _django_db_fixture_helper( + request, django_db_blocker, transactional=False, serialized_rollback=False + ) @pytest.fixture(scope="function") @@ -274,6 +304,8 @@ def transactional_db( """ if "django_db_reset_sequences" in request.fixturenames: request.getfixturevalue("django_db_reset_sequences") + if "django_db_serialized_rollback" in request.fixturenames: + request.getfixturevalue("django_db_serialized_rollback") _django_db_fixture_helper(request, django_db_blocker, transactional=True) @@ -299,6 +331,20 @@ def django_db_reset_sequences( ) +@pytest.fixture(scope="function") +def django_db_serialized_rollback(request, django_db_setup, django_db_blocker): + """Enable serialized rollback after transaction test cases + + This fixture only has an effect when the ``transactional_db`` + fixture is active, which happen as a side-effect of requesting + ``live_server``. + + """ + _django_db_fixture_helper( + request, django_db_blocker, transactional=True, serialized_rollback=True + ) + + @pytest.fixture() def client() -> "django.test.client.Client": """A Django test client instance.""" @@ -462,9 +508,11 @@ def live_server(request): """ skip_if_no_django() - addr = request.config.getvalue("liveserver") or os.getenv( - "DJANGO_LIVE_TEST_SERVER_ADDRESS" - ) or "localhost" + addr = ( + request.config.getvalue("liveserver") + or os.getenv("DJANGO_LIVE_TEST_SERVER_ADDRESS") + or "localhost" + ) server = live_server_helper.LiveServer(addr) request.addfinalizer(server.stop) @@ -549,11 +597,7 @@ def django_assert_max_num_queries(pytestconfig): @contextmanager -def _capture_on_commit_callbacks( - *, - using: Optional[str] = None, - execute: bool = False -): +def _capture_on_commit_callbacks(*, using: Optional[str] = None, execute: bool = False): from django.db import DEFAULT_DB_ALIAS, connections from django.test import TestCase @@ -574,7 +618,9 @@ def _capture_on_commit_callbacks( callback() else: - with TestCase.captureOnCommitCallbacks(using=using, execute=execute) as callbacks: + with TestCase.captureOnCommitCallbacks( + using=using, execute=execute + ) as callbacks: yield callbacks diff --git a/pytest_django/plugin.py b/pytest_django/plugin.py index b006305c..41d5cb23 100644 --- a/pytest_django/plugin.py +++ b/pytest_django/plugin.py @@ -15,7 +15,7 @@ import pytest from .django_compat import is_django_unittest # noqa -from .fixtures import _live_server_helper # noqa +from .fixtures import _live_server_helper # noqa; noqa from .fixtures import admin_client # noqa from .fixtures import admin_user # noqa from .fixtures import async_client # noqa @@ -40,6 +40,7 @@ from .fixtures import rf # noqa from .fixtures import settings # noqa from .fixtures import transactional_db # noqa +from .fixtures import django_db_serialized_rollback from .lazy_django import django_settings_is_configured, skip_if_no_django @@ -380,7 +381,7 @@ def get_order_number(test: pytest.Item) -> int: if issubclass(test_cls, TransactionTestCase): return 1 - marker_db = test.get_closest_marker('django_db') + marker_db = test.get_closest_marker("django_db") if not marker_db: transaction = None else: @@ -388,7 +389,7 @@ def get_order_number(test: pytest.Item) -> int: if transaction is True: return 1 - fixtures = getattr(test, 'fixturenames', []) + fixtures = getattr(test, "fixturenames", []) if "transactional_db" in fixtures: return 1 @@ -417,7 +418,8 @@ def django_test_environment(request) -> None: if django_settings_is_configured(): _setup_django() from django.test.utils import ( - setup_test_environment, teardown_test_environment, + setup_test_environment, + teardown_test_environment, ) debug_ini = request.config.getini("django_debug_mode") @@ -454,18 +456,26 @@ def django_db_blocker() -> "Optional[_DatabaseBlocker]": def _django_db_marker(request) -> None: """Implement the django_db marker, internal to pytest-django. - This will dynamically request the ``db``, ``transactional_db`` or - ``django_db_reset_sequences`` fixtures as required by the django_db marker. + This will dynamically request the ``db``, ``transactional_db``, + ``django_db_reset_sequences`` or ``django_db_serialized_rollback`` + fixtures as required by the django_db marker. """ marker = request.node.get_closest_marker("django_db") if marker: - transaction, reset_sequences, databases = validate_django_db(marker) + ( + transaction, + reset_sequences, + serialized_rollback, + databases, + ) = validate_django_db(marker) # TODO: Use pytest Store (item.store) once that's stable. request.node._pytest_django_databases = databases if reset_sequences: request.getfixturevalue("django_db_reset_sequences") + elif serialized_rollback: + request.getfixturevalue("django_db_serialized_rollback") elif transaction: request.getfixturevalue("transactional_db") else: @@ -486,6 +496,7 @@ def _django_setup_unittest( # Before pytest 5.4: https://github.com/pytest-dev/pytest/issues/5991 # After pytest 5.4: https://github.com/pytest-dev/pytest-django/issues/824 from _pytest.unittest import TestCaseFunction + original_runtest = TestCaseFunction.runtest def non_debugging_runtest(self) -> None: @@ -641,13 +652,15 @@ def __mod__(self, var: str) -> str: from django.conf import settings as dj_settings if dj_settings.TEMPLATES: - dj_settings.TEMPLATES[0]["OPTIONS"]["string_if_invalid"] = InvalidVarException() + dj_settings.TEMPLATES[0]["OPTIONS"][ + "string_if_invalid" + ] = InvalidVarException() @pytest.fixture(autouse=True) def _template_string_if_invalid_marker(request) -> None: """Apply the @pytest.mark.ignore_template_errors marker, - internal to pytest-django.""" + internal to pytest-django.""" marker = request.keywords.get("ignore_template_errors", None) if os.environ.get(INVALID_TEMPLATE_VARS_ENV, "false") == "true": if marker and django_settings_is_configured(): @@ -742,18 +755,20 @@ def validate_django_db(marker) -> "_DjangoDb": """Validate the django_db marker. It checks the signature and creates the ``transaction``, - ``reset_sequences`` and ``databases`` attributes on the marker - which will have the correct values. + ``reset_sequences`` and ``serialized_rollback`` attributes on + the marker which will have the correct values. A sequence reset is only allowed when combined with a transaction. + A serialized rollback is only allowed when combined with a transaction. """ def apifun( - transaction: bool = False, - reset_sequences: bool = False, + transaction=False, + reset_sequences=False, + serialized_rollback=False, databases: "_DjangoDbDatabases" = None, - ) -> "_DjangoDb": - return transaction, reset_sequences, databases + ): + return transaction, reset_sequences, serialized_rollback, databases return apifun(*marker.args, **marker.kwargs) diff --git a/tests/test_database.py b/tests/test_database.py index 9b5a88bd..73a5c6a4 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -48,10 +48,19 @@ def non_zero_sequences_counter(db: None) -> None: class TestDatabaseFixtures: """Tests for the different database fixtures.""" - @pytest.fixture(params=["db", "transactional_db", "django_db_reset_sequences"]) + @pytest.fixture( + params=[ + "db", + "transactional_db", + "django_db_reset_sequences", + "django_db_serialized_rollback", + ] + ) def all_dbs(self, request) -> None: if request.param == "django_db_reset_sequences": return request.getfixturevalue("django_db_reset_sequences") + elif request.param == "django_db_serialized_rollback": + return request.getfixturevalue("django_db_serialized_rollback") elif request.param == "transactional_db": return request.getfixturevalue("transactional_db") elif request.param == "db": @@ -77,7 +86,8 @@ def test_transactions_enabled(self, transactional_db: None) -> None: assert not connection.in_atomic_block def test_transactions_enabled_via_reset_seq( - self, django_db_reset_sequences: None, + self, + django_db_reset_sequences: None, ) -> None: if not connection.features.supports_transactions: pytest.skip("transactions required for this test") @@ -85,7 +95,10 @@ def test_transactions_enabled_via_reset_seq( assert not connection.in_atomic_block def test_django_db_reset_sequences_fixture( - self, db: None, django_testdir, non_zero_sequences_counter: None, + self, + db: None, + django_testdir, + non_zero_sequences_counter: None, ) -> None: if not db_supports_reset_sequences(): @@ -224,32 +237,42 @@ def test_reset_sequences_enabled(self, request) -> None: marker = request.node.get_closest_marker("django_db") assert marker.kwargs["reset_sequences"] - @pytest.mark.django_db(databases=['default', 'replica', 'second']) + @pytest.mark.django_db + def test_serialized_rollback_disabled(self, request): + marker = request.node.get_closest_marker("django_db") + assert not marker.kwargs + + @pytest.mark.django_db(serialized_rollback=True) + def test_serialized_rollback_enabled(self, request): + marker = request.node.get_closest_marker("django_db") + assert marker.kwargs["serialized_rollback"] + + @pytest.mark.django_db(databases=["default", "replica", "second"]) def test_databases(self, request) -> None: marker = request.node.get_closest_marker("django_db") - assert marker.kwargs["databases"] == ['default', 'replica', 'second'] + assert marker.kwargs["databases"] == ["default", "replica", "second"] - @pytest.mark.django_db(databases=['second']) + @pytest.mark.django_db(databases=["second"]) def test_second_database(self, request) -> None: SecondItem.objects.create(name="spam") - @pytest.mark.django_db(databases=['default']) + @pytest.mark.django_db(databases=["default"]) def test_not_allowed_database(self, request) -> None: - with pytest.raises(AssertionError, match='not allowed'): + with pytest.raises(AssertionError, match="not allowed"): SecondItem.objects.count() - with pytest.raises(AssertionError, match='not allowed'): + with pytest.raises(AssertionError, match="not allowed"): SecondItem.objects.create(name="spam") - @pytest.mark.django_db(databases=['replica']) + @pytest.mark.django_db(databases=["replica"]) def test_replica_database(self, request) -> None: - Item.objects.using('replica').count() + Item.objects.using("replica").count() - @pytest.mark.django_db(databases=['replica']) + @pytest.mark.django_db(databases=["replica"]) def test_replica_database_not_allowed(self, request) -> None: - with pytest.raises(AssertionError, match='not allowed'): + with pytest.raises(AssertionError, match="not allowed"): Item.objects.count() - @pytest.mark.django_db(databases='__all__') + @pytest.mark.django_db(databases="__all__") def test_all_databases(self, request) -> None: Item.objects.count() Item.objects.create(name="spam")