Skip to content

Commit 324b961

Browse files
syastrovmkurnikov
authored andcommitted
Support returning the correct values for the different QuerySet methods when using .values() and .values_list(). (#33)
* Support returning the correct values for the different QuerySet methods when using .values() and .values_list(). * Fix slicing on QuerySet. Fix django queries test, and remove some ignored errors that are no longer needed. * Remove accidental change in RawQuerySet. * Readded some still-necessary ignores to aggregation django test. * Add more tests of first/last/earliest/last/__getitem__, per mkurnikov's comments. - Fix .iterator() * Re-add Iterator as base-class of QuerySet. * Make QuerySet a Collection. * - Fix return type for QuerySet.select_for_update(). - Use correct return type for QuerySet.dates() / QuerySet.datetimes(). - Use correct type params in return type for QuerySet.__and__ / QuerySet.__or__ - Re-add Sized as base class for QuerySet. - Add test of .all() for all _Row types. - Add test of .get() for all _Row types. - Remove some redundant QuerySet method tests. * Automatically fill in second type parameter for QuerySet. ... if second parameter is omitted.
1 parent 86c63d7 commit 324b961

File tree

6 files changed

+185
-73
lines changed

6 files changed

+185
-73
lines changed

django-stubs/db/models/manager.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ from django.db.models.query import QuerySet
55

66
_T = TypeVar("_T", bound=Model, covariant=True)
77

8-
class BaseManager(QuerySet[_T]):
8+
class BaseManager(QuerySet[_T, _T]):
99
creation_counter: int = ...
1010
auto_created: bool = ...
1111
use_in_migrations: bool = ...
@@ -21,7 +21,7 @@ class BaseManager(QuerySet[_T]):
2121
def _get_queryset_methods(cls, queryset_class: type) -> Dict[str, Any]: ...
2222
def contribute_to_class(self, model: Type[Model], name: str) -> None: ...
2323
def db_manager(self, using: Optional[str] = ..., hints: Optional[Dict[str, Model]] = ...) -> Manager: ...
24-
def get_queryset(self) -> QuerySet[_T]: ...
24+
def get_queryset(self) -> QuerySet[_T, _T]: ...
2525

2626
class Manager(BaseManager[_T]): ...
2727

django-stubs/db/models/query.pyi

Lines changed: 58 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import datetime
12
from typing import (
23
Any,
34
Dict,
@@ -13,6 +14,9 @@ from typing import (
1314
TypeVar,
1415
Union,
1516
overload,
17+
Generic,
18+
NamedTuple,
19+
Collection,
1620
)
1721

1822
from django.db.models.base import Model
@@ -46,7 +50,7 @@ class FlatValuesListIterable(BaseIterable):
4650

4751
_T = TypeVar("_T", bound=models.Model, covariant=True)
4852

49-
class QuerySet(Iterable[_T], Sized):
53+
class QuerySet(Generic[_T, _Row], Collection[_Row], Sized):
5054
query: Query
5155
def __init__(
5256
self,
@@ -58,32 +62,33 @@ class QuerySet(Iterable[_T], Sized):
5862
@classmethod
5963
def as_manager(cls) -> Manager[Any]: ...
6064
def __len__(self) -> int: ...
61-
def __iter__(self) -> Iterator[_T]: ...
65+
def __iter__(self) -> Iterator[_Row]: ...
66+
def __contains__(self, x: object) -> bool: ...
67+
@overload
68+
def __getitem__(self, i: int) -> _Row: ...
69+
@overload
70+
def __getitem__(self, s: slice) -> QuerySet[_T, _Row]: ...
6271
def __bool__(self) -> bool: ...
6372
def __class_getitem__(cls, item: Type[_T]):
6473
pass
6574
def __getstate__(self) -> Dict[str, Any]: ...
66-
@overload
67-
def __getitem__(self, k: int) -> _T: ...
68-
@overload
69-
def __getitem__(self, k: str) -> Any: ...
70-
@overload
71-
def __getitem__(self, k: slice) -> QuerySet[_T]: ...
72-
def __and__(self, other: QuerySet) -> QuerySet: ...
73-
def __or__(self, other: QuerySet) -> QuerySet: ...
74-
def iterator(self, chunk_size: int = ...) -> Iterator[_T]: ...
75+
# __and__ and __or__ ignore the other QuerySet's _Row type parameter because they use the same row type as the self QuerySet.
76+
# Technically, the other QuerySet must be of the same type _T, but _T is covariant
77+
def __and__(self, other: QuerySet[_T, Any]) -> QuerySet[_T, _Row]: ...
78+
def __or__(self, other: QuerySet[_T, Any]) -> QuerySet[_T, _Row]: ...
79+
def iterator(self, chunk_size: int = ...) -> Iterator[_Row]: ...
7580
def aggregate(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: ...
76-
def get(self, *args: Any, **kwargs: Any) -> _T: ...
81+
def get(self, *args: Any, **kwargs: Any) -> _Row: ...
7782
def create(self, **kwargs: Any) -> _T: ...
7883
def bulk_create(self, objs: Iterable[Model], batch_size: Optional[int] = ...) -> List[_T]: ...
7984
def get_or_create(self, defaults: Optional[MutableMapping[str, Any]] = ..., **kwargs: Any) -> Tuple[_T, bool]: ...
8085
def update_or_create(
8186
self, defaults: Optional[MutableMapping[str, Any]] = ..., **kwargs: Any
8287
) -> Tuple[_T, bool]: ...
83-
def earliest(self, *fields: Any, field_name: Optional[Any] = ...) -> _T: ...
84-
def latest(self, *fields: Any, field_name: Optional[Any] = ...) -> _T: ...
85-
def first(self) -> Optional[_T]: ...
86-
def last(self) -> Optional[_T]: ...
88+
def earliest(self, *fields: Any, field_name: Optional[Any] = ...) -> _Row: ...
89+
def latest(self, *fields: Any, field_name: Optional[Any] = ...) -> _Row: ...
90+
def first(self) -> Optional[_Row]: ...
91+
def last(self) -> Optional[_Row]: ...
8792
def in_bulk(self, id_list: Iterable[Any] = ..., *, field_name: str = ...) -> Dict[Any, _T]: ...
8893
def delete(self) -> Tuple[int, Dict[str, int]]: ...
8994
def update(self, **kwargs: Any) -> int: ...
@@ -93,31 +98,38 @@ class QuerySet(Iterable[_T], Sized):
9398
def raw(
9499
self, raw_query: str, params: Any = ..., translations: Optional[Dict[str, str]] = ..., using: None = ...
95100
) -> RawQuerySet: ...
96-
def values(self, *fields: Union[str, Combinable], **expressions: Any) -> QuerySet: ...
97-
def values_list(self, *fields: Union[str, Combinable], flat: bool = ..., named: bool = ...) -> QuerySet: ...
98-
# @overload
99-
# def values_list(self, *fields: Union[str, Combinable], named: Literal[True]) -> NamedValuesListIterable: ...
100-
# @overload
101-
# def values_list(self, *fields: Union[str, Combinable], flat: Literal[True]) -> FlatValuesListIterable: ...
102-
# @overload
103-
# def values_list(self, *fields: Union[str, Combinable]) -> ValuesListIterable: ...
104-
def dates(self, field_name: str, kind: str, order: str = ...) -> QuerySet: ...
105-
def datetimes(self, field_name: str, kind: str, order: str = ..., tzinfo: None = ...) -> QuerySet: ...
106-
def none(self) -> QuerySet[_T]: ...
107-
def all(self) -> QuerySet[_T]: ...
108-
def filter(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
109-
def exclude(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
110-
def complex_filter(self, filter_obj: Any) -> QuerySet[_T]: ...
101+
def values(self, *fields: Union[str, Combinable], **expressions: Any) -> QuerySet[_T, Dict[str, Any]]: ...
102+
@overload
103+
def values_list(
104+
self, *fields: Union[str, Combinable], flat: Literal[False] = ..., named: Literal[True]
105+
) -> QuerySet[_T, NamedTuple]: ...
106+
@overload
107+
def values_list(
108+
self, *fields: Union[str, Combinable], flat: Literal[True], named: Literal[False] = ...
109+
) -> QuerySet[_T, Any]: ...
110+
@overload
111+
def values_list(
112+
self, *fields: Union[str, Combinable], flat: Literal[False] = ..., named: Literal[False] = ...
113+
) -> QuerySet[_T, Tuple]: ...
114+
def dates(self, field_name: str, kind: str, order: str = ...) -> QuerySet[_T, datetime.date]: ...
115+
def datetimes(
116+
self, field_name: str, kind: str, order: str = ..., tzinfo: None = ...
117+
) -> QuerySet[_T, datetime.datetime]: ...
118+
def none(self) -> QuerySet[_T, _Row]: ...
119+
def all(self) -> QuerySet[_T, _Row]: ...
120+
def filter(self, *args: Any, **kwargs: Any) -> QuerySet[_T, _Row]: ...
121+
def exclude(self, *args: Any, **kwargs: Any) -> QuerySet[_T, _Row]: ...
122+
def complex_filter(self, filter_obj: Any) -> QuerySet[_T, _Row]: ...
111123
def count(self) -> int: ...
112-
def union(self, *other_qs: Any, all: bool = ...) -> QuerySet[_T]: ...
113-
def intersection(self, *other_qs: Any) -> QuerySet[_T]: ...
114-
def difference(self, *other_qs: Any) -> QuerySet[_T]: ...
115-
def select_for_update(self, nowait: bool = ..., skip_locked: bool = ..., of: Tuple = ...) -> QuerySet: ...
116-
def select_related(self, *fields: Any) -> QuerySet[_T]: ...
117-
def prefetch_related(self, *lookups: Any) -> QuerySet[_T]: ...
118-
def annotate(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
119-
def order_by(self, *field_names: Any) -> QuerySet[_T]: ...
120-
def distinct(self, *field_names: Any) -> QuerySet[_T]: ...
124+
def union(self, *other_qs: Any, all: bool = ...) -> QuerySet[_T, _Row]: ...
125+
def intersection(self, *other_qs: Any) -> QuerySet[_T, _Row]: ...
126+
def difference(self, *other_qs: Any) -> QuerySet[_T, _Row]: ...
127+
def select_for_update(self, nowait: bool = ..., skip_locked: bool = ..., of: Tuple = ...) -> QuerySet[_T, _Row]: ...
128+
def select_related(self, *fields: Any) -> QuerySet[_T, _Row]: ...
129+
def prefetch_related(self, *lookups: Any) -> QuerySet[_T, _Row]: ...
130+
def annotate(self, *args: Any, **kwargs: Any) -> QuerySet[_T, _Row]: ...
131+
def order_by(self, *field_names: Any) -> QuerySet[_T, _Row]: ...
132+
def distinct(self, *field_names: Any) -> QuerySet[_T, _Row]: ...
121133
def extra(
122134
self,
123135
select: Optional[Dict[str, Any]] = ...,
@@ -126,11 +138,11 @@ class QuerySet(Iterable[_T], Sized):
126138
tables: Optional[List[str]] = ...,
127139
order_by: Optional[Sequence[str]] = ...,
128140
select_params: Optional[Sequence[Any]] = ...,
129-
) -> QuerySet[_T]: ...
130-
def reverse(self) -> QuerySet[_T]: ...
131-
def defer(self, *fields: Any) -> QuerySet[_T]: ...
132-
def only(self, *fields: Any) -> QuerySet[_T]: ...
133-
def using(self, alias: Optional[str]) -> QuerySet[_T]: ...
141+
) -> QuerySet[_T, _Row]: ...
142+
def reverse(self) -> QuerySet[_T, _Row]: ...
143+
def defer(self, *fields: Any) -> QuerySet[_T, _Row]: ...
144+
def only(self, *fields: Any) -> QuerySet[_T, _Row]: ...
145+
def using(self, alias: Optional[str]) -> QuerySet[_T, _Row]: ...
134146
@property
135147
def ordered(self) -> bool: ...
136148
@property
@@ -159,7 +171,7 @@ class RawQuerySet(Iterable[_T], Sized):
159171
@overload
160172
def __getitem__(self, k: str) -> Any: ...
161173
@overload
162-
def __getitem__(self, k: slice) -> QuerySet[_T]: ...
174+
def __getitem__(self, k: slice) -> RawQuerySet[_T]: ...
163175
@property
164176
def columns(self) -> List[str]: ...
165177
@property

django-stubs/shortcuts.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,6 @@ def redirect(
3131

3232
_T = TypeVar("_T", bound=Model)
3333

34-
def get_object_or_404(klass: Union[Type[_T], Manager[_T], QuerySet[_T]], *args: Any, **kwargs: Any) -> _T: ...
35-
def get_list_or_404(klass: Union[Type[_T], Manager[_T], QuerySet[_T]], *args: Any, **kwargs: Any) -> List[_T]: ...
34+
def get_object_or_404(klass: Union[Type[_T], Manager[_T], QuerySet[_T, _T]], *args: Any, **kwargs: Any) -> _T: ...
35+
def get_list_or_404(klass: Union[Type[_T], Manager[_T], QuerySet[_T, _T]], *args: Any, **kwargs: Any) -> List[_T]: ...
3636
def resolve_url(to: Union[Callable, Model, str], *args: Any, **kwargs: Any) -> str: ...

mypy_django_plugin/main.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from functools import partial
2+
13
import os
24
from typing import Callable, Dict, Optional, Union, cast
35

@@ -6,7 +8,7 @@
68
from mypy.options import Options
79
from mypy.plugin import (
810
AttributeContext, ClassDefContext, FunctionContext, MethodContext, Plugin,
9-
)
11+
AnalyzeTypeContext)
1012
from mypy.types import (
1113
AnyType, CallableType, Instance, NoneTyp, Type, TypeOfAny, TypeType, UnionType,
1214
)
@@ -80,6 +82,18 @@ def determine_proper_manager_type(ctx: FunctionContext) -> Type:
8082
return ret
8183

8284

85+
def set_first_generic_param_as_default_for_second(fullname: str, ctx: AnalyzeTypeContext) -> Type:
86+
if not ctx.type.args:
87+
return ctx.api.named_type(fullname, [AnyType(TypeOfAny.explicit),
88+
AnyType(TypeOfAny.explicit)])
89+
args = ctx.type.args
90+
if len(args) == 1:
91+
args = [args[0], args[0]]
92+
93+
analyzed_args = [ctx.api.analyze_type(arg) for arg in args]
94+
return ctx.api.named_type(fullname, analyzed_args)
95+
96+
8397
def return_user_model_hook(ctx: FunctionContext) -> Type:
8498
api = cast(TypeChecker, ctx.api)
8599
setting_expr = helpers.get_setting_expr(api, 'AUTH_USER_MODEL')
@@ -266,6 +280,14 @@ def _get_current_form_bases(self) -> Dict[str, int]:
266280
else:
267281
return {}
268282

283+
def _get_current_queryset_bases(self) -> Dict[str, int]:
284+
model_sym = self.lookup_fully_qualified(helpers.QUERYSET_CLASS_FULLNAME)
285+
if model_sym is not None and isinstance(model_sym.node, TypeInfo):
286+
return (helpers.get_django_metadata(model_sym.node)
287+
.setdefault('queryset_bases', {helpers.QUERYSET_CLASS_FULLNAME: 1}))
288+
else:
289+
return {}
290+
269291
def get_function_hook(self, fullname: str
270292
) -> Optional[Callable[[FunctionContext], Type]]:
271293
if fullname == 'django.contrib.auth.get_user_model':
@@ -344,6 +366,14 @@ def get_attribute_hook(self, fullname: str
344366

345367
return extract_and_return_primary_key_of_bound_related_field_parameter
346368

369+
def get_type_analyze_hook(self, fullname: str
370+
) -> Optional[Callable[[AnalyzeTypeContext], Type]]:
371+
queryset_bases = self._get_current_queryset_bases()
372+
if fullname in queryset_bases:
373+
return partial(set_first_generic_param_as_default_for_second, fullname)
374+
375+
return None
376+
347377

348378
def plugin(version):
349379
return DjangoPlugin

scripts/typecheck_tests.py

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -91,20 +91,11 @@
9191
'Argument "is_dst" to "localize" of "BaseTzInfo" has incompatible type "None"; expected "bool"'
9292
],
9393
'aggregation': [
94-
'Incompatible types in assignment (expression has type "QuerySet[Any]", variable has type "List[Any]")',
9594
'"as_sql" undefined in superclass',
96-
'Incompatible types in assignment (expression has type "FlatValuesListIterable", '
97-
+ 'variable has type "ValuesListIterable")',
9895
'Incompatible type for "contact" of "Book" (got "Optional[Author]", expected "Union[Author, Combinable]")',
9996
'Incompatible type for "publisher" of "Book" (got "Optional[Publisher]", '
10097
+ 'expected "Union[Publisher, Combinable]")'
10198
],
102-
'aggregation_regress': [
103-
'Incompatible types in assignment (expression has type "List[str]", variable has type "QuerySet[Author]")',
104-
'Incompatible types in assignment (expression has type "FlatValuesListIterable", '
105-
+ 'variable has type "QuerySet[Any]")',
106-
'Too few arguments for "count" of "Sequence"'
107-
],
10899
'apps': [
109100
'Incompatible types in assignment (expression has type "str", target has type "type")',
110101
'"Callable[[bool, bool], List[Type[Model]]]" has no attribute "cache_clear"'
@@ -159,9 +150,6 @@
159150
'db_typecasts': [
160151
'"object" has no attribute "__iter__"; maybe "__str__" or "__dir__"? (not iterable)'
161152
],
162-
'expressions': [
163-
'Argument 1 to "Subquery" has incompatible type "Sequence[Dict[str, Any]]"; expected "QuerySet[Any]"'
164-
],
165153
'from_db_value': [
166154
'has no attribute "vendor"'
167155
],
@@ -199,9 +187,9 @@
199187
],
200188
'get_object_or_404': [
201189
'Argument 1 to "get_object_or_404" has incompatible type "str"; '
202-
+ 'expected "Union[Type[<nothing>], QuerySet[<nothing>]]"',
190+
+ 'expected "Union[Type[<nothing>], QuerySet[<nothing>, <nothing>]]"',
203191
'Argument 1 to "get_list_or_404" has incompatible type "List[Type[Article]]"; '
204-
+ 'expected "Union[Type[<nothing>], QuerySet[<nothing>]]"',
192+
+ 'expected "Union[Type[<nothing>], QuerySet[<nothing>, <nothing>]]"',
205193
'CustomClass'
206194
],
207195
'get_or_create': [
@@ -227,10 +215,6 @@
227215
'many_to_one': [
228216
'Incompatible type for "parent" of "Child" (got "None", expected "Union[Parent, Combinable]")'
229217
],
230-
'model_inheritance_regress': [
231-
'Incompatible types in assignment (expression has type "List[Supplier]", '
232-
+ 'variable has type "QuerySet[Supplier]")'
233-
],
234218
'model_meta': [
235219
'"object" has no attribute "items"',
236220
'"Field" has no attribute "many_to_many"'
@@ -305,7 +289,8 @@
305289
],
306290
'queries': [
307291
'Incompatible types in assignment (expression has type "None", variable has type "str")',
308-
'Invalid index type "Optional[str]" for "Dict[str, int]"; expected type "str"'
292+
'Invalid index type "Optional[str]" for "Dict[str, int]"; expected type "str"',
293+
'No overload variant of "values_list" of "QuerySet" matches argument types "str", "bool", "bool"',
309294
],
310295
'requests': [
311296
'Incompatible types in assignment (expression has type "Dict[str, str]", variable has type "QueryDict")'
@@ -314,7 +299,7 @@
314299
'Argument 1 to "TextIOWrapper" has incompatible type "HttpResponse"; expected "IO[bytes]"'
315300
],
316301
'prefetch_related': [
317-
'Incompatible types in assignment (expression has type "List[Room]", variable has type "QuerySet[Room]")',
302+
'Incompatible types in assignment (expression has type "List[Room]", variable has type "QuerySet[Room, Room]")',
318303
'"None" has no attribute "__iter__"',
319304
'has no attribute "read_by"'
320305
],

0 commit comments

Comments
 (0)