|
1 | 1 | from gettext import gettext
|
2 | 2 | from typing import (
|
3 | 3 | TYPE_CHECKING,
|
| 4 | + Callable, |
4 | 5 | Dict,
|
5 | 6 | List,
|
6 | 7 | Optional,
|
@@ -44,10 +45,9 @@ class SnapshotAssertion:
|
44 | 45 | _test_location: "TestLocation" = attr.ib(kw_only=True)
|
45 | 46 | _update_snapshots: bool = attr.ib(kw_only=True)
|
46 | 47 | _extension: Optional["AbstractSyrupyExtension"] = attr.ib(init=False, default=None)
|
47 |
| - _executions: int = attr.ib(init=False, default=0, kw_only=True) |
48 |
| - _execution_results: Dict[int, "AssertionResult"] = attr.ib( |
49 |
| - init=False, factory=dict, kw_only=True |
50 |
| - ) |
| 48 | + _executions: int = attr.ib(init=False, default=0) |
| 49 | + _execution_results: Dict[int, "AssertionResult"] = attr.ib(init=False, factory=dict) |
| 50 | + _post_assert_actions: List[Callable[..., None]] = attr.ib(init=False, factory=list) |
51 | 51 |
|
52 | 52 | def __attrs_post_init__(self) -> None:
|
53 | 53 | self._session.register_request(self)
|
@@ -108,6 +108,11 @@ def __call__(
|
108 | 108 | """
|
109 | 109 | if extension_class:
|
110 | 110 | self._extension = self.__init_extension(extension_class)
|
| 111 | + |
| 112 | + def clear_extension() -> None: |
| 113 | + self._extension = None |
| 114 | + |
| 115 | + self._post_assert_actions.append(clear_extension) |
111 | 116 | return self
|
112 | 117 |
|
113 | 118 | def __repr__(self) -> str:
|
@@ -155,7 +160,8 @@ def _post_assert(self) -> None:
|
155 | 160 | """
|
156 | 161 | Restores assertion instance options
|
157 | 162 | """
|
158 |
| - self._extension = None |
| 163 | + while self._post_assert_actions: |
| 164 | + self._post_assert_actions.pop()() |
159 | 165 |
|
160 | 166 | def _recall_data(self, index: int) -> Optional["SerializableData"]:
|
161 | 167 | try:
|
|
0 commit comments