Skip to content

Commit b7f7713

Browse files
committed
add support for get_user_model(), fixes #16
1 parent 2720b74 commit b7f7713

File tree

5 files changed

+128
-28
lines changed

5 files changed

+128
-28
lines changed

mypy_django_plugin/helpers.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import typing
22
from typing import Dict, Optional
33

4-
from mypy.nodes import Expression, ImportedName, MypyFile, NameExpr, SymbolNode, TypeInfo
4+
from mypy.checker import TypeChecker
5+
from mypy.nodes import AssignmentStmt, Expression, ImportedName, Lvalue, MypyFile, NameExpr, Statement, SymbolNode, TypeInfo, \
6+
ClassDef
57
from mypy.plugin import FunctionContext
68
from mypy.types import AnyType, Instance, Type, TypeOfAny, TypeVarType
79

@@ -146,3 +148,40 @@ def get_argument_type_by_name(ctx: FunctionContext, name: str) -> Optional[Type]
146148
# Either an error or no value passed.
147149
return None
148150
return arg_types[0]
151+
152+
153+
def get_setting_expr(api: TypeChecker, setting_name: str) -> Optional[Expression]:
154+
try:
155+
settings_sym = api.modules['django.conf'].names['settings']
156+
except KeyError:
157+
return None
158+
159+
settings_type: TypeInfo = settings_sym.type.type
160+
auth_user_model_sym = settings_type.get(setting_name)
161+
if not auth_user_model_sym:
162+
return None
163+
164+
module, _, name = auth_user_model_sym.fullname.rpartition('.')
165+
if module not in api.modules:
166+
return None
167+
168+
module_file = api.modules.get(module)
169+
for name_expr, value_expr in iter_over_assignments(module_file):
170+
if isinstance(name_expr, NameExpr) and name_expr.name == setting_name:
171+
return value_expr
172+
return None
173+
174+
175+
def iter_over_assignments(class_or_module: typing.Union[ClassDef, MypyFile]) -> typing.Iterator[typing.Tuple[Lvalue, Expression]]:
176+
if isinstance(class_or_module, ClassDef):
177+
statements = class_or_module.defs.body
178+
else:
179+
statements = class_or_module.defs
180+
181+
for stmt in statements:
182+
if not isinstance(stmt, AssignmentStmt):
183+
continue
184+
if len(stmt.lvalues) > 1:
185+
# not supported yet
186+
continue
187+
yield stmt.lvalues[0], stmt.rvalue

mypy_django_plugin/main.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55
from mypy.nodes import TypeInfo
66
from mypy.options import Options
77
from mypy.plugin import ClassDefContext, FunctionContext, MethodContext, Plugin
8-
from mypy.types import Instance, Type
8+
from mypy.types import Instance, Type, TypeType
99

1010
from mypy_django_plugin import helpers, monkeypatch
1111
from mypy_django_plugin.config import Config
1212
from mypy_django_plugin.plugins.fields import determine_type_of_array_field, record_field_properties_into_outer_model_class
1313
from mypy_django_plugin.plugins.init_create import redefine_and_typecheck_model_init, redefine_and_typecheck_model_create
14-
from mypy_django_plugin.plugins.migrations import determine_model_cls_from_string_for_migrations
14+
from mypy_django_plugin.plugins.migrations import determine_model_cls_from_string_for_migrations, get_string_value_from_expr
1515
from mypy_django_plugin.plugins.models import process_model_class
1616
from mypy_django_plugin.plugins.related_fields import extract_to_parameter_as_get_ret_type_for_related_field, reparametrize_with
1717
from mypy_django_plugin.plugins.settings import AddSettingValuesToDjangoConfObject
@@ -56,6 +56,32 @@ def determine_proper_manager_type(ctx: FunctionContext) -> Type:
5656
return ret
5757

5858

59+
def return_user_model_hook(ctx: FunctionContext) -> Type:
60+
api = cast(TypeChecker, ctx.api)
61+
setting_expr = helpers.get_setting_expr(api, 'AUTH_USER_MODEL')
62+
if setting_expr is None:
63+
return ctx.default_return_type
64+
65+
app_label, _, model_class_name = get_string_value_from_expr(setting_expr).rpartition('.')
66+
if app_label is None:
67+
return ctx.default_return_type
68+
69+
model_fullname = helpers.get_model_fullname(app_label, model_class_name,
70+
all_modules=api.modules)
71+
if model_fullname is None:
72+
api.fail(f'"{app_label}.{model_class_name}" model class is not imported so far. Try to import it '
73+
f'(under if TYPE_CHECKING) at the beginning of the current file',
74+
context=ctx.context)
75+
return ctx.default_return_type
76+
77+
model_info = helpers.lookup_fully_qualified_generic(model_fullname,
78+
all_modules=api.modules)
79+
if model_info is None or not isinstance(model_info, TypeInfo):
80+
return ctx.default_return_type
81+
return TypeType(Instance(model_info, []))
82+
83+
84+
5985
class DjangoPlugin(Plugin):
6086
def __init__(self, options: Options) -> None:
6187
super().__init__(options)
@@ -105,6 +131,9 @@ def _get_current_manager_bases(self) -> Dict[str, int]:
105131

106132
def get_function_hook(self, fullname: str
107133
) -> Optional[Callable[[FunctionContext], Type]]:
134+
if fullname == 'django.contrib.auth.get_user_model':
135+
return return_user_model_hook
136+
108137
if fullname in {helpers.FOREIGN_KEY_FULLNAME,
109138
helpers.ONETOONE_FIELD_FULLNAME,
110139
helpers.MANYTOMANY_FIELD_FULLNAME}:

mypy_django_plugin/plugins/models.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from mypy.types import AnyType, Instance, NoneTyp, TypeOfAny
1111

1212
from mypy_django_plugin import helpers
13+
from mypy_django_plugin.helpers import iter_over_assignments
1314

1415

1516
@dataclasses.dataclass
@@ -55,16 +56,6 @@ def run(self) -> None:
5556
raise NotImplementedError()
5657

5758

58-
def iter_over_assignments(klass: ClassDef) -> Iterator[Tuple[Lvalue, Expression]]:
59-
for stmt in klass.defs.body:
60-
if not isinstance(stmt, AssignmentStmt):
61-
continue
62-
if len(stmt.lvalues) > 1:
63-
# not supported yet
64-
continue
65-
yield stmt.lvalues[0], stmt.rvalue
66-
67-
6859
def iter_call_assignments(klass: ClassDef) -> Iterator[Tuple[Lvalue, CallExpr]]:
6960
for lvalue, rvalue in iter_over_assignments(klass):
7061
if isinstance(rvalue, CallExpr):

test-data/typecheck/managers.test

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -136,19 +136,4 @@ class AbstractBase2(models.Model):
136136

137137
class Child(AbstractBase1, AbstractBase2):
138138
pass
139-
[out]
140-
141-
[CASE get_object_or_404_returns_proper_types]
142-
from django.shortcuts import get_object_or_404, get_list_or_404
143-
from django.db import models
144-
145-
class MyModel(models.Model):
146-
pass
147-
reveal_type(get_object_or_404(MyModel)) # E: Revealed type is 'main.MyModel*'
148-
reveal_type(get_object_or_404(MyModel.objects)) # E: Revealed type is 'main.MyModel*'
149-
reveal_type(get_object_or_404(MyModel.objects.get_queryset())) # E: Revealed type is 'main.MyModel*'
150-
151-
reveal_type(get_list_or_404(MyModel)) # E: Revealed type is 'builtins.list[main.MyModel*]'
152-
reveal_type(get_list_or_404(MyModel.objects)) # E: Revealed type is 'builtins.list[main.MyModel*]'
153-
reveal_type(get_list_or_404(MyModel.objects.get_queryset())) # E: Revealed type is 'builtins.list[main.MyModel*]'
154139
[out]

test-data/typecheck/shortcuts.test

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
[CASE get_object_or_404_returns_proper_types]
2+
from django.shortcuts import get_object_or_404, get_list_or_404
3+
from django.db import models
4+
5+
class MyModel(models.Model):
6+
pass
7+
reveal_type(get_object_or_404(MyModel)) # E: Revealed type is 'main.MyModel*'
8+
reveal_type(get_object_or_404(MyModel.objects)) # E: Revealed type is 'main.MyModel*'
9+
reveal_type(get_object_or_404(MyModel.objects.get_queryset())) # E: Revealed type is 'main.MyModel*'
10+
11+
reveal_type(get_list_or_404(MyModel)) # E: Revealed type is 'builtins.list[main.MyModel*]'
12+
reveal_type(get_list_or_404(MyModel.objects)) # E: Revealed type is 'builtins.list[main.MyModel*]'
13+
reveal_type(get_list_or_404(MyModel.objects.get_queryset())) # E: Revealed type is 'builtins.list[main.MyModel*]'
14+
[out]
15+
16+
[CASE get_user_model_returns_proper_class]
17+
from typing import TYPE_CHECKING
18+
if TYPE_CHECKING:
19+
from myapp.models import MyUser
20+
from django.contrib.auth import get_user_model
21+
22+
UserModel = get_user_model()
23+
reveal_type(UserModel.objects) # E: Revealed type is 'django.db.models.manager.Manager[myapp.models.MyUser]'
24+
25+
[env DJANGO_SETTINGS_MODULE=mysettings]
26+
[file mysettings.py]
27+
INSTALLED_APPS = ('myapp',)
28+
AUTH_USER_MODEL = 'myapp.MyUser'
29+
30+
[file myapp/__init__.py]
31+
[file myapp/models.py]
32+
from django.db import models
33+
class MyUser(models.Model):
34+
pass
35+
[out]
36+
37+
[CASE return_type_model_and_show_error_if_model_not_yet_imported]
38+
from django.contrib.auth import get_user_model
39+
40+
UserModel = get_user_model()
41+
reveal_type(UserModel.objects)
42+
43+
[env DJANGO_SETTINGS_MODULE=mysettings]
44+
[file mysettings.py]
45+
INSTALLED_APPS = ('myapp',)
46+
AUTH_USER_MODEL = 'myapp.MyUser'
47+
48+
[file myapp/__init__.py]
49+
[file myapp/models.py]
50+
from django.db import models
51+
class MyUser(models.Model):
52+
pass
53+
[out]
54+
main:3: error: "myapp.MyUser" model class is not imported so far. Try to import it (under if TYPE_CHECKING) at the beginning of the current file
55+
main:4: error: Revealed type is 'Any'
56+
main:4: error: "Type[Model]" has no attribute "objects"

0 commit comments

Comments
 (0)