Skip to content

Commit 82eae91

Browse files
authored
perf: only clear assertion _extension when overridden (#172)
* test: extension is not cleared when not overridden * perf: only clear extension when overridden
1 parent 3ba0e4c commit 82eae91

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

src/syrupy/assertion.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from gettext import gettext
22
from typing import (
33
TYPE_CHECKING,
4+
Callable,
45
Dict,
56
List,
67
Optional,
@@ -44,10 +45,9 @@ class SnapshotAssertion:
4445
_test_location: "TestLocation" = attr.ib(kw_only=True)
4546
_update_snapshots: bool = attr.ib(kw_only=True)
4647
_extension: Optional["AbstractSyrupyExtension"] = attr.ib(init=False, default=None)
47-
_executions: int = attr.ib(init=False, default=0, kw_only=True)
48-
_execution_results: Dict[int, "AssertionResult"] = attr.ib(
49-
init=False, factory=dict, kw_only=True
50-
)
48+
_executions: int = attr.ib(init=False, default=0)
49+
_execution_results: Dict[int, "AssertionResult"] = attr.ib(init=False, factory=dict)
50+
_post_assert_actions: List[Callable[..., None]] = attr.ib(init=False, factory=list)
5151

5252
def __attrs_post_init__(self) -> None:
5353
self._session.register_request(self)
@@ -108,6 +108,11 @@ def __call__(
108108
"""
109109
if extension_class:
110110
self._extension = self.__init_extension(extension_class)
111+
112+
def clear_extension() -> None:
113+
self._extension = None
114+
115+
self._post_assert_actions.append(clear_extension)
111116
return self
112117

113118
def __repr__(self) -> str:
@@ -155,7 +160,8 @@ def _post_assert(self) -> None:
155160
"""
156161
Restores assertion instance options
157162
"""
158-
self._extension = None
163+
while self._post_assert_actions:
164+
self._post_assert_actions.pop()()
159165

160166
def _recall_data(self, index: int) -> Optional["SerializableData"]:
161167
try:

tests/test_extension_image.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,5 +51,6 @@ def test_multiple_snapshot_extensions(snapshot):
5151
"""
5252
assert actual_svg == snapshot(extension_class=SVGImageSnapshotExtension)
5353
assert actual_svg == snapshot # uses initial extension class
54+
assert snapshot._extension is not None
5455
assert actual_png == snapshot(extension_class=PNGImageSnapshotExtension)
5556
assert actual_svg == snapshot(extension_class=SVGImageSnapshotExtension)

0 commit comments

Comments
 (0)