Skip to content

Commit b825ec1

Browse files
committed
Make context retries respect config
1 parent b65d1dd commit b825ec1

File tree

3 files changed

+40
-22
lines changed

3 files changed

+40
-22
lines changed

src/stamina/_core.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,11 @@ def from_params(
105105
)
106106

107107
def __iter__(self) -> _t.Retrying:
108+
if not _CONFIG.is_active:
109+
return _t.Retrying(
110+
reraise=True, stop=_t.stop_after_attempt(1)
111+
).__iter__()
112+
108113
return _t.Retrying(
109114
before_sleep=_make_before_sleep(
110115
self._name, _CONFIG.on_retry, self._args, self._kw
@@ -115,6 +120,11 @@ def __iter__(self) -> _t.Retrying:
115120
).__iter__()
116121

117122
def __aiter__(self) -> _t.AsyncRetrying:
123+
if not _CONFIG.is_active:
124+
return _t.AsyncRetrying(
125+
reraise=True, stop=_t.stop_after_attempt(1)
126+
).__aiter__()
127+
118128
return _t.AsyncRetrying(
119129
before_sleep=_make_before_sleep(
120130
self._name, _CONFIG.on_retry, self._args, self._kw
@@ -229,9 +239,6 @@ def retry_decorator(wrapped: Callable[P, T]) -> Callable[P, T]:
229239

230240
@wraps(wrapped)
231241
def sync_inner(*args: P.args, **kw: P.kwargs) -> T: # type: ignore[return]
232-
if not _CONFIG.is_active:
233-
return wrapped(*args, **kw)
234-
235242
for attempt in retry_ctx.with_name( # noqa: RET503
236243
name, args, kw
237244
):
@@ -242,9 +249,6 @@ def sync_inner(*args: P.args, **kw: P.kwargs) -> T: # type: ignore[return]
242249

243250
@wraps(wrapped)
244251
async def async_inner(*args: P.args, **kw: P.kwargs) -> T: # type: ignore[return]
245-
if not _CONFIG.is_active:
246-
return await wrapped(*args, **kw) # type: ignore[no-any-return,misc]
247-
248252
async for attempt in retry_ctx.with_name( # noqa: RET503
249253
name, args, kw
250254
):

tests/test_async.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
#
33
# SPDX-License-Identifier: MIT
44

5-
from unittest.mock import Mock
65

76
import pytest
87

@@ -80,36 +79,52 @@ async def f():
8079
await f()
8180

8281

83-
async def test_retry_inactive(monkeypatch):
82+
async def test_retry_inactive():
8483
"""
8584
If inactive, don't retry.
8685
"""
86+
num_called = 0
8787

8888
@stamina.retry(on=Exception)
8989
async def f():
90+
nonlocal num_called
91+
num_called += 1
9092
raise Exception("passed")
9193

9294
stamina.set_active(False)
9395

94-
retrying = Mock()
95-
monkeypatch.setattr(stamina._core._t, "AsyncRetrying", retrying)
96-
9796
with pytest.raises(Exception, match="passed"):
9897
await f()
9998

100-
retrying.assert_not_called()
99+
assert 1 == num_called
101100

102101

103102
async def test_retry_block():
104103
"""
105104
Async retry_context blocks are retried.
106105
"""
107-
i = 0
106+
num_called = 0
108107

109108
async for attempt in stamina.retry_context(on=ValueError, wait_max=0):
110109
with attempt:
111-
i += 1
112-
if i < 2:
110+
num_called += 1
111+
if num_called < 2:
113112
raise ValueError
114113

115-
assert 2 == i
114+
assert 2 == num_called
115+
116+
117+
async def test_retry_blocks_can_be_disabled():
118+
"""
119+
Async context retries respect the config.
120+
"""
121+
stamina.set_active(False)
122+
num_called = 0
123+
124+
with pytest.raises(Exception, match="passed"):
125+
async for attempt in stamina.retry_context(on=Exception, attempts=2):
126+
with attempt:
127+
num_called += 1
128+
raise Exception("passed")
129+
130+
assert 1 == num_called

tests/test_sync.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
#
33
# SPDX-License-Identifier: MIT
44

5-
from unittest.mock import Mock
65

76
import pytest
87
import tenacity
@@ -57,24 +56,24 @@ def f():
5756
f()
5857

5958

60-
def test_retry_inactive(monkeypatch):
59+
def test_retry_inactive():
6160
"""
6261
If inactive, don't retry.
6362
"""
63+
num_called = 0
6464

6565
@stamina.retry(on=Exception)
6666
def f():
67+
nonlocal num_called
68+
num_called += 1
6769
raise Exception("passed")
6870

6971
stamina.set_active(False)
7072

71-
retrying = Mock()
72-
monkeypatch.setattr(stamina._core._t, "Retrying", retrying)
73-
7473
with pytest.raises(Exception, match="passed"):
7574
f()
7675

77-
retrying.assert_not_called()
76+
assert 1 == num_called
7877

7978

8079
def test_retry_block():

0 commit comments

Comments
 (0)