Skip to content

Commit 194258a

Browse files
authored
Merge pull request #23 from syastrov/better-types-for-transaction-atomic
Add better typings plus test for transaction.atomic.
2 parents 116aa2c + 67c9943 commit 194258a

File tree

2 files changed

+62
-9
lines changed

2 files changed

+62
-9
lines changed

django-stubs/db/transaction.pyi

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from contextlib import ContextDecorator
2-
from typing import Any, Callable, Optional, Union, Iterator, overload, ContextManager
1+
from typing import Any, Callable, Optional, overload, TypeVar
32

43
from django.db import ProgrammingError
54

@@ -18,19 +17,23 @@ def get_rollback(using: None = ...) -> bool: ...
1817
def set_rollback(rollback: bool, using: Optional[str] = ...) -> None: ...
1918
def on_commit(func: Callable, using: None = ...) -> None: ...
2019

21-
class Atomic(ContextDecorator):
20+
_C = TypeVar("_C", bound=Callable) # Any callable
21+
22+
# Don't inherit from ContextDecorator, so we can provide a more specific signature for __call__
23+
class Atomic:
2224
using: Optional[str] = ...
2325
savepoint: bool = ...
2426
def __init__(self, using: Optional[str], savepoint: bool) -> None: ...
27+
# When decorating, return the decorated function as-is, rather than clobbering it as ContextDecorator does.
28+
def __call__(self, func: _C) -> _C: ...
2529
def __enter__(self) -> None: ...
2630
def __exit__(self, exc_type: None, exc_value: None, traceback: None) -> None: ...
2731

32+
# Bare decorator
2833
@overload
29-
def atomic() -> Atomic: ...
30-
@overload
31-
def atomic(using: Optional[str] = ...,) -> ContextManager[Atomic]: ...
32-
@overload
33-
def atomic(using: Callable = ...) -> Callable: ...
34+
def atomic(using: _C) -> _C: ...
35+
36+
# Decorator or context-manager with parameters
3437
@overload
35-
def atomic(using: Optional[str] = ..., savepoint: bool = ...) -> ContextManager[Atomic]: ...
38+
def atomic(using: Optional[str] = None, savepoint: bool = True) -> Atomic: ...
3639
def non_atomic_requests(using: Callable = ...) -> Callable: ...

test-data/typecheck/transaction.test

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
[CASE test_transaction_atomic]
2+
3+
from django.db import transaction
4+
5+
with transaction.atomic():
6+
pass
7+
8+
with transaction.atomic(using="mydb"):
9+
pass
10+
11+
with transaction.atomic(using="mydb", savepoint=False):
12+
pass
13+
14+
@transaction.atomic()
15+
def decorated_func(param1: str, param2: int) -> bool:
16+
pass
17+
18+
# Ensure that the function's type is preserved
19+
reveal_type(decorated_func) # E: Revealed type is 'def (param1: builtins.str, param2: builtins.int) -> builtins.bool'
20+
21+
@transaction.atomic(using="mydb")
22+
def decorated_func_using(param1: str, param2: int) -> bool:
23+
pass
24+
25+
# Ensure that the function's type is preserved
26+
reveal_type(decorated_func_using) # E: Revealed type is 'def (param1: builtins.str, param2: builtins.int) -> builtins.bool'
27+
28+
class ClassWithAtomicMethod:
29+
# Bare decorator
30+
@transaction.atomic
31+
def atomic_method1(self, abc: int) -> str:
32+
pass
33+
34+
@transaction.atomic(savepoint=True)
35+
def atomic_method2(self):
36+
pass
37+
38+
@transaction.atomic(using="db", savepoint=True)
39+
def atomic_method3(self, myparam: str) -> int:
40+
pass
41+
42+
ClassWithAtomicMethod().atomic_method1("abc") # E: Argument 1 to "atomic_method1" of "ClassWithAtomicMethod" has incompatible type "str"; expected "int"
43+
44+
# Ensure that the method's type is preserved
45+
reveal_type(ClassWithAtomicMethod().atomic_method1) # E: Revealed type is 'def (abc: builtins.int) -> builtins.str'
46+
47+
# Ensure that the method's type is preserved
48+
reveal_type(ClassWithAtomicMethod().atomic_method3) # E: Revealed type is 'def (myparam: builtins.str) -> builtins.int'
49+
50+
[out]

0 commit comments

Comments
 (0)