Skip to content

Commit 7108a0c

Browse files
committed
Make QuerySet[_T] an external alias to _QuerySet[_T, _T].
This currently has the drawback that error messages display the internal type _QuerySet, with both type arguments. See also discussion on typeddjango#661 and typeddjango#608. Fixes typeddjango#635: QuerySet methods on Managers (like .all()) now return QuerySets rather than Managers. Address code review by @sobolevn.
1 parent 751ae7f commit 7108a0c

File tree

18 files changed

+183
-95
lines changed

18 files changed

+183
-95
lines changed

django-stubs/db/models/manager.pyi

Lines changed: 95 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,30 @@
1-
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union
1+
import datetime
2+
from typing import (
3+
Any,
4+
Dict,
5+
Generic,
6+
Iterable,
7+
Iterator,
8+
List,
9+
MutableMapping,
10+
Optional,
11+
Sequence,
12+
Tuple,
13+
Type,
14+
TypeVar,
15+
Union,
16+
)
217

18+
from django.db.models import Combinable
319
from django.db.models.base import Model
4-
from django.db.models.query import QuerySet
20+
from django.db.models.query import QuerySet, RawQuerySet
21+
22+
from django_stubs_ext import ValuesQuerySet
523

624
_T = TypeVar("_T", bound=Model, covariant=True)
725
_M = TypeVar("_M", bound="BaseManager")
826

9-
class BaseManager(QuerySet[_T]):
27+
class BaseManager(Generic[_T]):
1028
creation_counter: int = ...
1129
auto_created: bool = ...
1230
use_in_migrations: bool = ...
@@ -24,6 +42,80 @@ class BaseManager(QuerySet[_T]):
2442
def contribute_to_class(self, model: Type[Model], name: str) -> None: ...
2543
def db_manager(self: _M, using: Optional[str] = ..., hints: Optional[Dict[str, Model]] = ...) -> _M: ...
2644
def get_queryset(self) -> QuerySet[_T]: ...
45+
# NOTE: The following methods are in common with QuerySet, but note that the use of QuerySet as a return type
46+
# rather than a self-type (_QS), since Manager's QuerySet-like methods return QuerySets and not Managers.
47+
def iterator(self, chunk_size: int = ...) -> Iterator[_T]: ...
48+
def aggregate(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: ...
49+
def get(self, *args: Any, **kwargs: Any) -> _T: ...
50+
def create(self, *args: Any, **kwargs: Any) -> _T: ...
51+
def bulk_create(
52+
self, objs: Iterable[_T], batch_size: Optional[int] = ..., ignore_conflicts: bool = ...
53+
) -> List[_T]: ...
54+
def bulk_update(self, objs: Iterable[_T], fields: Sequence[str], batch_size: Optional[int] = ...) -> None: ...
55+
def get_or_create(self, defaults: Optional[MutableMapping[str, Any]] = ..., **kwargs: Any) -> Tuple[_T, bool]: ...
56+
def update_or_create(
57+
self, defaults: Optional[MutableMapping[str, Any]] = ..., **kwargs: Any
58+
) -> Tuple[_T, bool]: ...
59+
def earliest(self, *fields: Any, field_name: Optional[Any] = ...) -> _T: ...
60+
def latest(self, *fields: Any, field_name: Optional[Any] = ...) -> _T: ...
61+
def first(self) -> Optional[_T]: ...
62+
def last(self) -> Optional[_T]: ...
63+
def in_bulk(self, id_list: Iterable[Any] = ..., *, field_name: str = ...) -> Dict[Any, _T]: ...
64+
def delete(self) -> Tuple[int, Dict[str, int]]: ...
65+
def update(self, **kwargs: Any) -> int: ...
66+
def exists(self) -> bool: ...
67+
def explain(self, *, format: Optional[Any] = ..., **options: Any) -> str: ...
68+
def raw(
69+
self,
70+
raw_query: str,
71+
params: Any = ...,
72+
translations: Optional[Dict[str, str]] = ...,
73+
using: Optional[str] = ...,
74+
) -> RawQuerySet: ...
75+
# The type of values may be overridden to be more specific in the mypy plugin, depending on the fields param
76+
def values(self, *fields: Union[str, Combinable], **expressions: Any) -> ValuesQuerySet[_T, Dict[str, Any]]: ...
77+
# The type of values_list may be overridden to be more specific in the mypy plugin, depending on the fields param
78+
def values_list(
79+
self, *fields: Union[str, Combinable], flat: bool = ..., named: bool = ...
80+
) -> ValuesQuerySet[_T, Any]: ...
81+
def dates(self, field_name: str, kind: str, order: str = ...) -> ValuesQuerySet[_T, datetime.date]: ...
82+
def datetimes(
83+
self, field_name: str, kind: str, order: str = ..., tzinfo: Optional[datetime.tzinfo] = ...
84+
) -> ValuesQuerySet[_T, datetime.datetime]: ...
85+
def none(self) -> QuerySet[_T]: ...
86+
def all(self) -> QuerySet[_T]: ...
87+
def filter(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
88+
def exclude(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
89+
def complex_filter(self, filter_obj: Any) -> QuerySet[_T]: ...
90+
def count(self) -> int: ...
91+
def union(self, *other_qs: Any, all: bool = ...) -> QuerySet[_T]: ...
92+
def intersection(self, *other_qs: Any) -> QuerySet[_T]: ...
93+
def difference(self, *other_qs: Any) -> QuerySet[_T]: ...
94+
def select_for_update(
95+
self, nowait: bool = ..., skip_locked: bool = ..., of: Sequence[str] = ..., no_key: bool = ...
96+
) -> QuerySet[_T]: ...
97+
def select_related(self, *fields: Any) -> QuerySet[_T]: ...
98+
def prefetch_related(self, *lookups: Any) -> QuerySet[_T]: ...
99+
def annotate(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
100+
def alias(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
101+
def order_by(self, *field_names: Any) -> QuerySet[_T]: ...
102+
def distinct(self, *field_names: Any) -> QuerySet[_T]: ...
103+
# extra() return type won't be supported any time soon
104+
def extra(
105+
self,
106+
select: Optional[Dict[str, Any]] = ...,
107+
where: Optional[List[str]] = ...,
108+
params: Optional[List[Any]] = ...,
109+
tables: Optional[List[str]] = ...,
110+
order_by: Optional[Sequence[str]] = ...,
111+
select_params: Optional[Sequence[Any]] = ...,
112+
) -> QuerySet[Any]: ...
113+
def reverse(self) -> QuerySet[_T]: ...
114+
def defer(self, *fields: Any) -> QuerySet[_T]: ...
115+
def only(self, *fields: Any) -> QuerySet[_T]: ...
116+
def using(self, alias: Optional[str]) -> QuerySet[_T]: ...
117+
@property
118+
def ordered(self) -> bool: ...
27119

28120
class Manager(BaseManager[_T]): ...
29121

django-stubs/db/models/query.pyi

Lines changed: 22 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,10 @@ from django.db.models.query_utils import Q as Q # noqa: F401
2828
from django.db.models.sql.query import Query, RawQuery
2929

3030
_T = TypeVar("_T", bound=models.Model, covariant=True)
31-
_QS = TypeVar("_QS", bound="QuerySet")
31+
_Row = TypeVar("_Row", covariant=True)
32+
_QS = TypeVar("_QS", bound="_QuerySet")
3233

33-
class QuerySet(Generic[_T], Collection[_T], Reversible[_T], Sized):
34+
class _QuerySet(Generic[_T, _Row], Collection[_Row], Reversible[_Row], Sized):
3435
model: Type[_T]
3536
query: Query
3637
def __init__(
@@ -47,11 +48,13 @@ class QuerySet(Generic[_T], Collection[_T], Reversible[_T], Sized):
4748
def __class_getitem__(cls: Type[_QS], item: Type[_T]) -> Type[_QS]: ...
4849
def __getstate__(self) -> Dict[str, Any]: ...
4950
# Technically, the other QuerySet must be of the same type _T, but _T is covariant
50-
def __and__(self: _QS, other: QuerySet[_T]) -> _QS: ...
51-
def __or__(self: _QS, other: QuerySet[_T]) -> _QS: ...
52-
def iterator(self, chunk_size: int = ...) -> Iterator[_T]: ...
51+
def __and__(self: _QS, other: _QuerySet[_T, _Row]) -> _QS: ...
52+
def __or__(self: _QS, other: _QuerySet[_T, _Row]) -> _QS: ...
53+
# IMPORTANT: When updating any of the following methods' signatures, please ALSO modify
54+
# the corresponding method in BaseManager.
55+
def iterator(self, chunk_size: int = ...) -> Iterator[_Row]: ...
5356
def aggregate(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: ...
54-
def get(self, *args: Any, **kwargs: Any) -> _T: ...
57+
def get(self, *args: Any, **kwargs: Any) -> _Row: ...
5558
def create(self, *args: Any, **kwargs: Any) -> _T: ...
5659
def bulk_create(
5760
self, objs: Iterable[_T], batch_size: Optional[int] = ..., ignore_conflicts: bool = ...
@@ -61,10 +64,10 @@ class QuerySet(Generic[_T], Collection[_T], Reversible[_T], Sized):
6164
def update_or_create(
6265
self, defaults: Optional[MutableMapping[str, Any]] = ..., **kwargs: Any
6366
) -> Tuple[_T, bool]: ...
64-
def earliest(self, *fields: Any, field_name: Optional[Any] = ...) -> _T: ...
65-
def latest(self, *fields: Any, field_name: Optional[Any] = ...) -> _T: ...
66-
def first(self) -> Optional[_T]: ...
67-
def last(self) -> Optional[_T]: ...
67+
def earliest(self, *fields: Any, field_name: Optional[Any] = ...) -> _Row: ...
68+
def latest(self, *fields: Any, field_name: Optional[Any] = ...) -> _Row: ...
69+
def first(self) -> Optional[_Row]: ...
70+
def last(self) -> Optional[_Row]: ...
6871
def in_bulk(self, id_list: Iterable[Any] = ..., *, field_name: str = ...) -> Dict[Any, _T]: ...
6972
def delete(self) -> Tuple[int, Dict[str, int]]: ...
7073
def update(self, **kwargs: Any) -> int: ...
@@ -78,15 +81,15 @@ class QuerySet(Generic[_T], Collection[_T], Reversible[_T], Sized):
7881
using: Optional[str] = ...,
7982
) -> RawQuerySet: ...
8083
# The type of values may be overridden to be more specific in the mypy plugin, depending on the fields param
81-
def values(self, *fields: Union[str, Combinable], **expressions: Any) -> _ValuesQuerySet[_T, Dict[str, Any]]: ...
84+
def values(self, *fields: Union[str, Combinable], **expressions: Any) -> _QuerySet[_T, Dict[str, Any]]: ...
8285
# The type of values_list may be overridden to be more specific in the mypy plugin, depending on the fields param
8386
def values_list(
8487
self, *fields: Union[str, Combinable], flat: bool = ..., named: bool = ...
85-
) -> _ValuesQuerySet[_T, Any]: ...
86-
def dates(self, field_name: str, kind: str, order: str = ...) -> _ValuesQuerySet[_T, datetime.date]: ...
88+
) -> _QuerySet[_T, Any]: ...
89+
def dates(self, field_name: str, kind: str, order: str = ...) -> _QuerySet[_T, datetime.date]: ...
8790
def datetimes(
8891
self, field_name: str, kind: str, order: str = ..., tzinfo: Optional[datetime.tzinfo] = ...
89-
) -> _ValuesQuerySet[_T, datetime.datetime]: ...
92+
) -> _QuerySet[_T, datetime.datetime]: ...
9093
def none(self: _QS) -> _QS: ...
9194
def all(self: _QS) -> _QS: ...
9295
def filter(self: _QS, *args: Any, **kwargs: Any) -> _QS: ...
@@ -114,7 +117,7 @@ class QuerySet(Generic[_T], Collection[_T], Reversible[_T], Sized):
114117
tables: Optional[List[str]] = ...,
115118
order_by: Optional[Sequence[str]] = ...,
116119
select_params: Optional[Sequence[Any]] = ...,
117-
) -> QuerySet[Any]: ...
120+
) -> _QuerySet[Any, Any]: ...
118121
def reverse(self: _QS) -> _QS: ...
119122
def defer(self: _QS, *fields: Any) -> _QS: ...
120123
def only(self: _QS, *fields: Any) -> _QS: ...
@@ -124,28 +127,13 @@ class QuerySet(Generic[_T], Collection[_T], Reversible[_T], Sized):
124127
@property
125128
def db(self) -> str: ...
126129
def resolve_expression(self, *args: Any, **kwargs: Any) -> Any: ...
127-
def __iter__(self) -> Iterator[_T]: ...
130+
def __iter__(self) -> Iterator[_Row]: ...
128131
def __contains__(self, x: object) -> bool: ...
129132
@overload
130-
def __getitem__(self, i: int) -> _T: ...
131-
@overload
132-
def __getitem__(self: _QS, s: slice) -> _QS: ...
133-
def __reversed__(self) -> Iterator[_T]: ...
134-
135-
_Row = TypeVar("_Row", covariant=True)
136-
137-
class _ValuesQuerySet(QuerySet[_T], Collection[_Row], Reversible[_Row], Sized):
138-
def __iter__(self) -> Iterator[_Row]: ... # type: ignore
139-
@overload # type: ignore
140133
def __getitem__(self, i: int) -> _Row: ...
141134
@overload
142135
def __getitem__(self: _QS, s: slice) -> _QS: ...
143-
def iterator(self, chunk_size: int = ...) -> Iterator[_Row]: ... # type: ignore
144-
def get(self, *args: Any, **kwargs: Any) -> _Row: ... # type: ignore
145-
def earliest(self, *fields: Any, field_name: Optional[Any] = ...) -> _Row: ... # type: ignore
146-
def latest(self, *fields: Any, field_name: Optional[Any] = ...) -> _Row: ... # type: ignore
147-
def first(self) -> Optional[_Row]: ... # type: ignore
148-
def last(self) -> Optional[_Row]: ... # type: ignore
136+
def __reversed__(self) -> Iterator[_Row]: ...
149137

150138
class RawQuerySet(Iterable[_T], Sized):
151139
query: RawQuery
@@ -179,6 +167,8 @@ class RawQuerySet(Iterable[_T], Sized):
179167
def resolve_model_init_order(self) -> Tuple[List[str], List[int], List[Tuple[str, int]]]: ...
180168
def using(self, alias: Optional[str]) -> RawQuerySet[_T]: ...
181169

170+
QuerySet = _QuerySet[_T, _T]
171+
182172
class Prefetch(object):
183173
def __init__(self, lookup: str, queryset: Optional[QuerySet] = ..., to_attr: Optional[str] = ...) -> None: ...
184174
def __getstate__(self) -> Dict[str, Any]: ...

django-stubs/views/generic/list.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ from django.db.models.query import QuerySet
66
from django.http import HttpRequest, HttpResponse
77
from django.views.generic.base import ContextMixin, TemplateResponseMixin, View
88

9-
T = TypeVar("T", bound=Model)
9+
T = TypeVar("T", bound=Model, covariant=True)
1010

1111
class MultipleObjectMixin(Generic[T], ContextMixin):
1212
allow_empty: bool = ...
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import typing
22

33
if typing.TYPE_CHECKING:
4-
from django.db.models.query import _T, _Row, _ValuesQuerySet
4+
from django.db.models.query import _T, _QuerySet, _Row
55

6-
ValuesQuerySet = _ValuesQuerySet[_T, _Row]
6+
ValuesQuerySet = _QuerySet[_T, _Row]
77
else:
88
ValuesQuerySet = typing.Any

mypy_django_plugin/django/context.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from mypy.types import AnyType, Instance
2020
from mypy.types import Type as MypyType
2121
from mypy.types import TypeOfAny, UnionType
22-
from typing_extensions import Final
2322

2423
from mypy_django_plugin.lib import fullnames, helpers
2524
from mypy_django_plugin.lib.fullnames import WITH_ANNOTATIONS_FULLNAME
@@ -116,10 +115,10 @@ def model_modules(self) -> Dict[str, Set[Type[Model]]]:
116115

117116
def get_model_class_by_fullname(self, fullname: str) -> Optional[Type[Model]]:
118117
"""Returns None if Model is abstract"""
119-
ANNOTATED_PREFIX: Final = WITH_ANNOTATIONS_FULLNAME + "["
120-
if fullname.startswith(ANNOTATED_PREFIX):
118+
annotated_prefix = WITH_ANNOTATIONS_FULLNAME + "["
119+
if fullname.startswith(annotated_prefix):
121120
# For our "annotated models", extract the original model fullname
122-
fullname = fullname[len(ANNOTATED_PREFIX) :].rstrip("]")
121+
fullname = fullname[len(annotated_prefix) :].rstrip("]")
123122

124123
module, _, model_cls_name = fullname.rpartition(".")
125124
for model_cls in self.model_modules.get(module, set()):

mypy_django_plugin/lib/fullnames.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@
1111
MANYTOMANY_FIELD_FULLNAME = "django.db.models.fields.related.ManyToManyField"
1212
DUMMY_SETTINGS_BASE_CLASS = "django.conf._DjangoConfLazyObject"
1313

14-
QUERYSET_CLASS_FULLNAME = "django.db.models.query.QuerySet"
15-
VALUES_QUERYSET_CLASS_FULLNAME = "django.db.models.query._ValuesQuerySet"
14+
QUERYSET_CLASS_FULLNAME = "django.db.models.query._QuerySet"
1615
BASE_MANAGER_CLASS_FULLNAME = "django.db.models.manager.BaseManager"
1716
MANAGER_CLASS_FULLNAME = "django.db.models.manager.Manager"
1817
RELATED_MANAGER_CLASS = "django.db.models.manager.RelatedManager"

mypy_django_plugin/lib/helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def is_toml(filename: str) -> bool:
6262
def lookup_fully_qualified_sym(fullname: str, all_modules: Dict[str, MypyFile]) -> Optional[SymbolTableNode]:
6363
if "." not in fullname:
6464
return None
65-
if "[" in fullname:
65+
if "[" in fullname and "]" in fullname:
6666
# We sometimes generate fake fullnames like a.b.C[x.y.Z] to provide a better representation to users
6767
# Make sure that we handle lookups of those types of names correctly if the part inside [] contains "."
6868
bracket_start = fullname.index("[")

mypy_django_plugin/main.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -266,27 +266,28 @@ def get_method_hook(self, fullname: str) -> Optional[Callable[[MethodContext], M
266266
if info and info.has_base(fullnames.FORM_MIXIN_CLASS_FULLNAME):
267267
return forms.extract_proper_type_for_get_form
268268

269+
manager_classes = self._get_current_manager_bases()
270+
269271
if method_name == "values":
270272
info = self._get_typeinfo_or_none(class_fullname)
271-
if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
273+
if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME) or class_fullname in manager_classes:
272274
return partial(querysets.extract_proper_type_queryset_values, django_context=self.django_context)
273275

274276
if method_name == "values_list":
275277
info = self._get_typeinfo_or_none(class_fullname)
276-
if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
278+
if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME) or class_fullname in manager_classes:
277279
return partial(querysets.extract_proper_type_queryset_values_list, django_context=self.django_context)
278280

279281
if method_name == "annotate":
280282
info = self._get_typeinfo_or_none(class_fullname)
281-
if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
283+
if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME) or class_fullname in manager_classes:
282284
return partial(querysets.extract_proper_type_queryset_annotate, django_context=self.django_context)
283285

284286
if method_name == "get_field":
285287
info = self._get_typeinfo_or_none(class_fullname)
286288
if info and info.has_base(fullnames.OPTIONS_CLASS_FULLNAME):
287289
return partial(meta.return_proper_field_type_from_get_field, django_context=self.django_context)
288290

289-
manager_classes = self._get_current_manager_bases()
290291
if class_fullname in manager_classes and method_name == "create":
291292
return partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context)
292293
if class_fullname in manager_classes and method_name in {"filter", "get", "exclude"}:

mypy_django_plugin/transformers/models.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,6 @@ def run_with_model_cls(self, model_cls: Type[Model]) -> None:
197197
for manager_name, manager in model_cls._meta.managers_map.items():
198198
manager_class_name = manager.__class__.__name__
199199
manager_fullname = helpers.get_class_fullname(manager.__class__)
200-
201200
try:
202201
manager_info = self.lookup_typeinfo_or_incomplete_defn_error(manager_fullname)
203202
except helpers.IncompleteDefnException as exc:

0 commit comments

Comments
 (0)