Skip to content

Commit 5b455b7

Browse files
syastrovmkurnikov
authored andcommitted
Specific return types for values and values list (#53)
* Instead of using Literal types, overload QuerySet.values_list in the plugin. Fixes #43. - Add a couple of extra type checks that Django makes: 1) 'flat' and 'named' can't be used together. 2) 'flat' is not valid when values_list is called with more than one field. * Determine better row types for values_list/values based on fields specified. - In the case of values_list, we use a Row type with either a single primitive, Tuple, or NamedTuple. - In the case of values, we use a TypedDict. - In both cases, Any is used as a fallback for individual fields if those fields cannot be resolved. A couple other fixes I made along the way: - Don't create reverse relation for ForeignKeys with related_name='+' - Don't skip creating other related managers in AddRelatedManagers if a dynamic value is encountered for related_name parameter, or if the type cannot be determined. * Fix for TypedDict so that they are considered anonymous. * Clean up some comments. * Implement making TypedDict anonymous in a way that doesn't crash sometimes. * Fix flake8 errors. * Remove even uglier hack about making TypedDict anonymous. * Address review comments. Write a few better comments inside tests. * Fix crash when running with mypyc ("interpreted classes cannot inherit from compiled") due to the way I extended TypedDictType. - Implemented the hack in another way that works on mypyc. - Added a couple extra tests of accessing 'id' / 'pk' via values_list. * Fix flake8 errors. * Support annotation expressions (use type Any) for TypedDicts row types returned by values_list. - Bonus points: handle values_list gracefully (use type Any) where Tuples are returned where some of the fields arguments are not string literals.
1 parent 5c6be7a commit 5b455b7

File tree

11 files changed

+649
-73
lines changed

11 files changed

+649
-73
lines changed

django-stubs/__init__.pyi

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1-
from typing import Any
1+
from typing import Any, NamedTuple
22
from .utils.version import get_version as get_version
33

44
VERSION: Any
55
__version__: str
66

77
def setup(set_prefix: bool = ...) -> None: ...
8+
9+
# Used by mypy_django_plugin when returning a QuerySet row that is a NamedTuple where the field names are unknown
10+
class _NamedTupleAnyAttr(NamedTuple):
11+
def __getattr__(self, item: str) -> Any: ...
12+
def __setattr__(self, item: str, value: Any) -> None: ...

django-stubs/db/models/query.pyi

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,9 @@ class QuerySet(Generic[_T, _Row], Collection[_Row], Sized):
9797
def raw(
9898
self, raw_query: str, params: Any = ..., translations: Optional[Dict[str, str]] = ..., using: None = ...
9999
) -> RawQuerySet: ...
100+
# The type of values may be overridden to be more specific in the mypy plugin, depending on the fields param
100101
def values(self, *fields: Union[str, Combinable], **expressions: Any) -> QuerySet[_T, Dict[str, Any]]: ...
101-
# The type of values_list is overridden to be more specific in the mypy django plugin
102+
# The type of values_list may be overridden to be more specific in the mypy plugin, depending on the fields param
102103
def values_list(
103104
self, *fields: Union[str, Combinable], flat: bool = ..., named: bool = ...
104105
) -> QuerySet[_T, Any]: ...

mypy_django_plugin/helpers.py

Lines changed: 92 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
import typing
2-
from typing import Dict, Optional
2+
from collections import OrderedDict
3+
from typing import Dict, Optional, cast
34

4-
from mypy.checker import TypeChecker
5+
from mypy.checker import TypeChecker, gen_unique_name
6+
from mypy.mro import calculate_mro
57
from mypy.nodes import (
68
AssignmentStmt, ClassDef, Expression, ImportedName, Lvalue, MypyFile, NameExpr, SymbolNode, TypeInfo,
7-
)
9+
SymbolTable, SymbolTableNode, Block, GDEF, MDEF, Var)
810
from mypy.plugin import FunctionContext, MethodContext
911
from mypy.types import (
1012
AnyType, Instance, NoneTyp, Type, TypeOfAny, TypeVarType, UnionType,
11-
)
13+
TupleType, TypedDictType)
1214

1315
MODEL_CLASS_FULLNAME = 'django.db.models.base.Model'
1416
FIELD_FULLNAME = 'django.db.models.fields.Field'
@@ -211,7 +213,7 @@ def extract_field_setter_type(tp: Instance) -> Optional[Type]:
211213
return None
212214

213215

214-
def extract_field_getter_type(tp: Instance) -> Optional[Type]:
216+
def extract_field_getter_type(tp: Type) -> Optional[Type]:
215217
if not isinstance(tp, Instance):
216218
return None
217219
if tp.type.has_base(FIELD_FULLNAME):
@@ -235,6 +237,10 @@ def get_fields_metadata(model: TypeInfo) -> Dict[str, typing.Any]:
235237
return get_django_metadata(model).setdefault('fields', {})
236238

237239

240+
def get_lookups_metadata(model: TypeInfo) -> Dict[str, typing.Any]:
241+
return get_django_metadata(model).setdefault('lookups', {})
242+
243+
238244
def extract_explicit_set_type_of_model_primary_key(model: TypeInfo) -> Optional[Type]:
239245
"""
240246
If field with primary_key=True is set on the model, extract its __set__ type.
@@ -296,3 +302,84 @@ def get_assigned_value_for_class(type_info: TypeInfo, name: str) -> Optional[Exp
296302
if isinstance(lvalue, NameExpr) and lvalue.name == name:
297303
return rvalue
298304
return None
305+
306+
307+
def is_field_nullable(model: TypeInfo, field_name: str) -> bool:
308+
return get_fields_metadata(model).get(field_name, {}).get('null', False)
309+
310+
311+
def is_foreign_key(t: Type) -> bool:
312+
if not isinstance(t, Instance):
313+
return False
314+
return has_any_of_bases(t.type, (FOREIGN_KEY_FULLNAME, ONETOONE_FIELD_FULLNAME))
315+
316+
317+
def build_class_with_annotated_fields(api: TypeChecker, base: Type, fields: 'OrderedDict[str, Type]',
318+
name: str) -> Instance:
319+
"""Build an Instance with `name` that contains the specified `fields` as attributes and extends `base`."""
320+
# Credit: This code is largely copied/modified from TypeChecker.intersect_instance_callable and
321+
# NamedTupleAnalyzer.build_namedtuple_typeinfo
322+
323+
cur_module = cast(MypyFile, api.scope.stack[0])
324+
gen_name = gen_unique_name(name, cur_module.names)
325+
326+
cdef = ClassDef(name, Block([]))
327+
cdef.fullname = cur_module.fullname() + '.' + gen_name
328+
info = TypeInfo(SymbolTable(), cdef, cur_module.fullname())
329+
cdef.info = info
330+
info.bases = [base]
331+
332+
def add_field(var: Var, is_initialized_in_class: bool = False,
333+
is_property: bool = False) -> None:
334+
var.info = info
335+
var.is_initialized_in_class = is_initialized_in_class
336+
var.is_property = is_property
337+
var._fullname = '%s.%s' % (info.fullname(), var.name())
338+
info.names[var.name()] = SymbolTableNode(MDEF, var)
339+
340+
vars = [Var(item, typ) for item, typ in fields.items()]
341+
for var in vars:
342+
add_field(var, is_property=True)
343+
344+
calculate_mro(info)
345+
info.calculate_metaclass_type()
346+
347+
cur_module.names[gen_name] = SymbolTableNode(GDEF, info, plugin_generated=True)
348+
return Instance(info, [])
349+
350+
351+
def make_named_tuple(api: TypeChecker, fields: 'OrderedDict[str, Type]', name: str) -> Type:
352+
if not fields:
353+
# No fields specified, so fallback to a subclass of NamedTuple that allows
354+
# __getattr__ / __setattr__ for any attribute name.
355+
fallback = api.named_generic_type('django._NamedTupleAnyAttr', [])
356+
else:
357+
fallback = build_class_with_annotated_fields(
358+
api=api,
359+
base=api.named_generic_type('typing.NamedTuple', []),
360+
fields=fields,
361+
name=name
362+
)
363+
return TupleType(list(fields.values()), fallback=fallback)
364+
365+
366+
def make_typeddict(api: TypeChecker, fields: 'OrderedDict[str, Type]', required_keys: typing.Set[str]) -> Type:
367+
object_type = api.named_generic_type('mypy_extensions._TypedDict', [])
368+
typed_dict_type = TypedDictType(fields, required_keys=required_keys, fallback=object_type)
369+
return typed_dict_type
370+
371+
372+
def make_tuple(api: TypeChecker, fields: typing.List[Type]) -> Type:
373+
implicit_any = AnyType(TypeOfAny.special_form)
374+
fallback = api.named_generic_type('builtins.tuple', [implicit_any])
375+
return TupleType(fields, fallback=fallback)
376+
377+
378+
def get_private_descriptor_type(type_info: TypeInfo, private_field_name: str, is_nullable: bool) -> Type:
379+
node = type_info.get(private_field_name).node
380+
if isinstance(node, Var):
381+
descriptor_type = node.type
382+
if is_nullable:
383+
descriptor_type = make_optional(descriptor_type)
384+
return descriptor_type
385+
return AnyType(TypeOfAny.unannotated)

mypy_django_plugin/lookups.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
import dataclasses
2+
from typing import Union, List
3+
4+
from mypy.nodes import TypeInfo
5+
from mypy.plugin import CheckerPluginInterface
6+
from mypy.types import Type, Instance
7+
8+
from mypy_django_plugin import helpers
9+
10+
11+
@dataclasses.dataclass
12+
class RelatedModelNode:
13+
typ: Instance
14+
is_nullable: bool
15+
16+
17+
@dataclasses.dataclass
18+
class FieldNode:
19+
typ: Type
20+
21+
22+
LookupNode = Union[RelatedModelNode, FieldNode]
23+
24+
25+
class LookupException(Exception):
26+
pass
27+
28+
29+
def resolve_lookup(api: CheckerPluginInterface, model_type_info: TypeInfo, lookup: str) -> List[LookupNode]:
30+
"""Resolve a lookup str to a list of LookupNodes.
31+
32+
Each node represents a part of the lookup (separated by "__"), in order.
33+
Each node is the Model or Field that was resolved.
34+
35+
Raises LookupException if there were any issues resolving the lookup.
36+
"""
37+
lookup_parts = lookup.split("__")
38+
39+
nodes = []
40+
while lookup_parts:
41+
lookup_part = lookup_parts.pop(0)
42+
43+
if not nodes:
44+
current_node = None
45+
else:
46+
current_node = nodes[-1]
47+
48+
if current_node is None:
49+
new_node = resolve_model_lookup(api, model_type_info, lookup_part)
50+
elif isinstance(current_node, RelatedModelNode):
51+
new_node = resolve_model_lookup(api, current_node.typ.type, lookup_part)
52+
elif isinstance(current_node, FieldNode):
53+
raise LookupException(f"Field lookups not yet supported for lookup {lookup}")
54+
else:
55+
raise LookupException(f"Unsupported node type: {type(current_node)}")
56+
nodes.append(new_node)
57+
return nodes
58+
59+
60+
def resolve_model_lookup(api: CheckerPluginInterface, model_type_info: TypeInfo,
61+
lookup: str) -> LookupNode:
62+
"""Resolve a lookup on the given model."""
63+
if lookup == 'pk':
64+
# Primary keys are special-cased
65+
primary_key_type = helpers.extract_primary_key_type_for_get(model_type_info)
66+
if primary_key_type:
67+
return FieldNode(primary_key_type)
68+
else:
69+
# No PK, use the get type for AutoField as PK type.
70+
autofield_info = api.lookup_typeinfo('django.db.models.fields.AutoField')
71+
pk_type = helpers.get_private_descriptor_type(autofield_info, '_pyi_private_get_type',
72+
is_nullable=False)
73+
return FieldNode(pk_type)
74+
75+
field_name = get_actual_field_name_for_lookup_field(lookup, model_type_info)
76+
77+
field_node = model_type_info.get(field_name)
78+
if not field_node:
79+
raise LookupException(
80+
f'When resolving lookup "{lookup}", field "{field_name}" was not found in model {model_type_info.name()}')
81+
82+
if field_name.endswith('_id'):
83+
field_name_without_id = field_name.rstrip('_id')
84+
foreign_key_field = model_type_info.get(field_name_without_id)
85+
if foreign_key_field is not None and helpers.is_foreign_key(foreign_key_field.type):
86+
# Hack: If field ends with '_id' and there is a model field without the '_id' suffix, then use that field.
87+
field_node = foreign_key_field
88+
field_name = field_name_without_id
89+
90+
field_node_type = field_node.type
91+
if field_node_type is None or not isinstance(field_node_type, Instance):
92+
raise LookupException(
93+
f'When resolving lookup "{lookup}", could not determine type for {model_type_info.name()}.{field_name}')
94+
95+
if helpers.is_foreign_key(field_node_type):
96+
field_type = helpers.extract_field_getter_type(field_node_type)
97+
is_nullable = helpers.is_optional(field_type)
98+
if is_nullable:
99+
field_type = helpers.make_required(field_type)
100+
101+
if isinstance(field_type, Instance):
102+
return RelatedModelNode(typ=field_type, is_nullable=is_nullable)
103+
else:
104+
raise LookupException(f"Not an instance for field {field_type} lookup {lookup}")
105+
106+
field_type = helpers.extract_field_getter_type(field_node_type)
107+
108+
if field_type:
109+
return FieldNode(typ=field_type)
110+
else:
111+
# Not a Field
112+
if field_name == 'id':
113+
# If no 'id' field was fouond, use an int
114+
return FieldNode(api.named_generic_type('builtins.int', []))
115+
116+
related_manager_arg = None
117+
if field_node_type.type.has_base(helpers.RELATED_MANAGER_CLASS_FULLNAME):
118+
related_manager_arg = field_node_type.args[0]
119+
120+
if related_manager_arg is not None:
121+
# Reverse relation
122+
return RelatedModelNode(typ=related_manager_arg, is_nullable=True)
123+
raise LookupException(
124+
f'When resolving lookup "{lookup}", could not determine type for {model_type_info.name()}.{field_name}')
125+
126+
127+
def get_actual_field_name_for_lookup_field(lookup: str, model_type_info: TypeInfo) -> str:
128+
"""Attempt to find out the real field name if this lookup is a related_query_name (for reverse relations).
129+
130+
If it's not, return the original lookup.
131+
"""
132+
lookups_metadata = helpers.get_lookups_metadata(model_type_info)
133+
lookup_metadata = lookups_metadata.get(lookup)
134+
if lookup_metadata is None:
135+
# If not found on current model, look in all bases for their lookup metadata
136+
for base in model_type_info.mro:
137+
lookups_metadata = helpers.get_lookups_metadata(base)
138+
lookup_metadata = lookups_metadata.get(lookup)
139+
if lookup_metadata:
140+
break
141+
if not lookup_metadata:
142+
lookup_metadata = {}
143+
related_name = lookup_metadata.get('related_query_name_target', None)
144+
if related_name:
145+
# If the lookup is a related lookup, then look at the field specified by related_name.
146+
# This is to support if related_query_name is set and differs from.
147+
field_name = related_name
148+
else:
149+
field_name = lookup
150+
return field_name

mypy_django_plugin/main.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
from functools import partial
2-
31
import os
2+
from functools import partial
43
from typing import Callable, Dict, Optional, Union, cast
54

65
from mypy.checker import TypeChecker
@@ -23,6 +22,7 @@
2322
determine_model_cls_from_string_for_migrations, get_string_value_from_expr,
2423
)
2524
from mypy_django_plugin.transformers.models import process_model_class
25+
from mypy_django_plugin.transformers.queryset import extract_proper_type_for_values_and_values_list
2626
from mypy_django_plugin.transformers.settings import (
2727
AddSettingValuesToDjangoConfObject, get_settings_metadata,
2828
)
@@ -165,7 +165,7 @@ def extract_and_return_primary_key_of_bound_related_field_parameter(ctx: Attribu
165165
if primary_key_type:
166166
return primary_key_type
167167

168-
is_nullable = helpers.get_fields_metadata(ctx.type.type).get(field_name, {}).get('null', False)
168+
is_nullable = helpers.is_field_nullable(ctx.type.type, field_name)
169169
if is_nullable:
170170
return helpers.make_optional(ctx.default_attr_type)
171171

@@ -292,7 +292,10 @@ def __init__(self, options: Options) -> None:
292292
if self.django_settings_module:
293293
settings_modules.append(self.django_settings_module)
294294

295-
monkeypatch.add_modules_as_a_source_seed_files(settings_modules)
295+
auto_imports = ['mypy_extensions']
296+
auto_imports.extend(settings_modules)
297+
298+
monkeypatch.add_modules_as_a_source_seed_files(auto_imports)
296299
monkeypatch.inject_modules_as_dependencies_for_django_conf_settings(settings_modules)
297300

298301
def _get_current_model_bases(self) -> Dict[str, int]:
@@ -359,10 +362,10 @@ def get_method_hook(self, fullname: str
359362
if sym and isinstance(sym.node, TypeInfo) and sym.node.has_base(helpers.FORM_MIXIN_CLASS_FULLNAME):
360363
return extract_proper_type_for_get_form
361364

362-
if method_name == 'values_list':
365+
if method_name in ('values', 'values_list'):
363366
sym = self.lookup_fully_qualified(class_name)
364367
if sym and isinstance(sym.node, TypeInfo) and sym.node.has_base(helpers.QUERYSET_CLASS_FULLNAME):
365-
return extract_proper_type_for_values_list
368+
return partial(extract_proper_type_for_values_and_values_list, method_name)
366369

367370
if fullname in {'django.apps.registry.Apps.get_model',
368371
'django.db.migrations.state.StateApps.get_model'}:

mypy_django_plugin/transformers/fields.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from typing import Optional, cast
22

33
from mypy.checker import TypeChecker
4-
from mypy.nodes import ListExpr, NameExpr, StrExpr, TupleExpr, TypeInfo, Var
4+
from mypy.nodes import ListExpr, NameExpr, StrExpr, TupleExpr, TypeInfo
55
from mypy.plugin import FunctionContext
66
from mypy.types import (
7-
AnyType, CallableType, Instance, TupleType, Type, TypeOfAny, UnionType,
7+
AnyType, CallableType, Instance, TupleType, Type, UnionType,
88
)
99

1010
from mypy_django_plugin import helpers
@@ -88,23 +88,13 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext) -> Type:
8888
return helpers.reparametrize_instance(ctx.default_return_type, new_args=args)
8989

9090

91-
def get_private_descriptor_type(type_info: TypeInfo, private_field_name: str, is_nullable: bool) -> Type:
92-
node = type_info.get(private_field_name).node
93-
if isinstance(node, Var):
94-
descriptor_type = node.type
95-
if is_nullable:
96-
descriptor_type = helpers.make_optional(descriptor_type)
97-
return descriptor_type
98-
return AnyType(TypeOfAny.unannotated)
99-
100-
10191
def set_descriptor_types_for_field(ctx: FunctionContext) -> Instance:
10292
default_return_type = cast(Instance, ctx.default_return_type)
10393
is_nullable = helpers.parse_bool(helpers.get_argument_by_name(ctx, 'null'))
104-
set_type = get_private_descriptor_type(default_return_type.type, '_pyi_private_set_type',
105-
is_nullable=is_nullable)
106-
get_type = get_private_descriptor_type(default_return_type.type, '_pyi_private_get_type',
107-
is_nullable=is_nullable)
94+
set_type = helpers.get_private_descriptor_type(default_return_type.type, '_pyi_private_set_type',
95+
is_nullable=is_nullable)
96+
get_type = helpers.get_private_descriptor_type(default_return_type.type, '_pyi_private_get_type',
97+
is_nullable=is_nullable)
10898
return helpers.reparametrize_instance(default_return_type, [set_type, get_type])
10999

110100

0 commit comments

Comments
 (0)