Skip to content

Commit df021f6

Browse files
committed
add some support for proxy models
1 parent 5779607 commit df021f6

File tree

6 files changed

+61
-29
lines changed

6 files changed

+61
-29
lines changed

django-stubs/contrib/auth/base_user.pyi

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
1-
from typing import Any, Optional, Tuple, List, overload
1+
from typing import Any, Optional, Tuple, List, overload, TypeVar
2+
3+
from django.db.models.base import Model
24

35
from django.db import models
46

5-
class BaseUserManager(models.Manager):
7+
_T = TypeVar('_T', bound=Model)
8+
9+
class BaseUserManager(models.Manager[_T]):
610
@classmethod
711
def normalize_email(cls, email: Optional[str]) -> str: ...
812
def make_random_password(self, length: int = ..., allowed_chars: str = ...) -> str: ...
9-
def get_by_natural_key(self, username: Optional[str]) -> AbstractBaseUser: ...
13+
def get_by_natural_key(self, username: Optional[str]) -> _T: ...
1014

1115
class AbstractBaseUser(models.Model):
1216
password: models.CharField = ...

django-stubs/contrib/auth/models.pyi

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from typing import Any, Collection, Optional, Set, Tuple, Type, Union
1+
from typing import Any, List, Optional, Set, Tuple, Type, Union, TypeVar
22

33
from django.contrib.auth.base_user import AbstractBaseUser as AbstractBaseUser, BaseUserManager as BaseUserManager
44
from django.contrib.contenttypes.models import ContentType
5+
from django.db.models.base import Model
56
from django.db.models.manager import EmptyManager
67

78
from django.contrib.auth.validators import UnicodeUsernameValidator
@@ -27,13 +28,15 @@ class Group(models.Model):
2728
permissions: models.ManyToManyField = models.ManyToManyField(Permission)
2829
def natural_key(self): ...
2930

30-
class UserManager(BaseUserManager):
31+
_T = TypeVar('_T', bound=Model)
32+
33+
class UserManager(BaseUserManager[_T]):
3134
def create_user(
3235
self, username: str, email: Optional[str] = ..., password: Optional[str] = ..., **extra_fields: Any
33-
) -> AbstractBaseUser: ...
36+
) -> _T: ...
3437
def create_superuser(
3538
self, username: str, email: Optional[str], password: Optional[str], **extra_fields: Any
36-
) -> AbstractBaseUser: ...
39+
) -> _T: ...
3740

3841
class PermissionsMixin(models.Model):
3942
is_superuser: models.BooleanField = ...

mypy_django_plugin/django/context.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,19 @@
11
import os
22
from collections import defaultdict
33
from contextlib import contextmanager
4-
from typing import (
5-
TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple, Type,
6-
)
4+
from typing import (Any, Dict, Iterator, List, Optional, TYPE_CHECKING, Tuple, Type)
75

8-
from django.contrib.postgres.fields import ArrayField
96
from django.core.exceptions import FieldError
107
from django.db.models.base import Model
11-
from django.db.models.fields import CharField, Field, AutoField
128
from django.db.models.fields.related import ForeignKey, RelatedField
139
from django.db.models.fields.reverse_related import ForeignObjectRel
1410
from django.db.models.sql.query import Query
1511
from django.utils.functional import cached_property
1612
from mypy.checker import TypeChecker
17-
from mypy.types import Instance
18-
from mypy.types import Type as MypyType
13+
from mypy.types import Instance, Type as MypyType
1914

15+
from django.contrib.postgres.fields import ArrayField
16+
from django.db.models.fields import AutoField, CharField, Field
2017
from mypy_django_plugin.lib import helpers
2118

2219
if TYPE_CHECKING:
@@ -210,13 +207,19 @@ def get_expected_types(self, api: TypeChecker, model_cls: Type[Model], method: s
210207
if isinstance(field, ForeignKey):
211208
field_name = field.name
212209
foreign_key_info = helpers.lookup_class_typeinfo(api, field.__class__)
213-
related_model_info = helpers.lookup_class_typeinfo(api, field.related_model)
210+
211+
related_model = field.related_model
212+
if related_model._meta.proxy_for_model:
213+
related_model = field.related_model._meta.proxy_for_model
214+
215+
related_model_info = helpers.lookup_class_typeinfo(api, related_model)
214216
is_nullable = self.fields_context.get_field_nullability(field, method)
215217
foreign_key_set_type = helpers.get_private_descriptor_type(foreign_key_info,
216218
'_pyi_private_set_type',
217219
is_nullable=is_nullable)
218220
model_set_type = helpers.convert_any_to_type(foreign_key_set_type,
219221
Instance(related_model_info, []))
222+
220223
expected_types[field_name] = model_set_type
221224

222225
elif isinstance(field, GenericForeignKey):

mypy_django_plugin/lib/helpers.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,13 @@
11
from collections import OrderedDict
2-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union
2+
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Union
33

44
from mypy import checker
55
from mypy.checker import TypeChecker
66
from mypy.mro import calculate_mro
7-
from mypy.nodes import (
8-
GDEF, MDEF, Block, ClassDef, Expression, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolNode, SymbolTable,
9-
SymbolTableNode, TypeInfo, Var,
10-
)
7+
from mypy.nodes import (Block, ClassDef, Expression, GDEF, MDEF, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolNode,
8+
SymbolTable, SymbolTableNode, TypeInfo, Var)
119
from mypy.plugin import CheckerPluginInterface, FunctionContext, MethodContext
12-
from mypy.types import AnyType, Instance, NoneTyp, TupleType
13-
from mypy.types import Type as MypyType
14-
from mypy.types import TypedDictType, TypeOfAny, UnionType
10+
from mypy.types import AnyType, Instance, NoneTyp, TupleType, Type as MypyType, TypeOfAny, TypedDictType, UnionType
1511

1612
if TYPE_CHECKING:
1713
from mypy_django_plugin.django.context import DjangoContext

mypy_django_plugin/transformers/fields.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,20 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context
3939
return AnyType(TypeOfAny.from_error)
4040

4141
assert isinstance(current_field, RelatedField)
42-
referred_to_typeinfo = helpers.lookup_class_typeinfo(ctx.api, current_field.related_model)
43-
referred_to_type = Instance(referred_to_typeinfo, [])
42+
43+
related_model = related_model_to_set = current_field.related_model
44+
if related_model_to_set._meta.proxy_for_model:
45+
related_model_to_set = related_model._meta.proxy_for_model
46+
47+
related_model_info = helpers.lookup_class_typeinfo(ctx.api, related_model)
48+
related_model_to_set_info = helpers.lookup_class_typeinfo(ctx.api, related_model_to_set)
4449

4550
default_related_field_type = set_descriptor_types_for_field(ctx)
4651
# replace Any with referred_to_type
47-
args = []
48-
for default_arg in default_related_field_type.args:
49-
args.append(helpers.convert_any_to_type(default_arg, referred_to_type))
50-
52+
args = [
53+
helpers.convert_any_to_type(default_related_field_type.args[0], Instance(related_model_to_set_info, [])),
54+
helpers.convert_any_to_type(default_related_field_type.args[1], Instance(related_model_info, [])),
55+
]
5156
return helpers.reparametrize_instance(default_related_field_type, new_args=args)
5257

5358

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
- case: foreign_key_to_proxy_model_accepts_first_non_proxy_model
2+
main: |
3+
from myapp.models import Blog, Publisher, PublisherProxy
4+
Blog(publisher=Publisher())
5+
Blog.objects.create(publisher=Publisher())
6+
Blog().publisher = Publisher()
7+
reveal_type(Blog().publisher) # N: Revealed type is 'myapp.models.PublisherProxy*'
8+
installed_apps:
9+
- myapp
10+
files:
11+
- path: myapp/__init__.py
12+
- path: myapp/models.py
13+
content: |
14+
from django.db import models
15+
class Publisher(models.Model):
16+
pass
17+
class PublisherProxy(Publisher):
18+
class Meta:
19+
proxy = True
20+
class Blog(models.Model):
21+
publisher = models.ForeignKey(to=PublisherProxy, on_delete=models.CASCADE)

0 commit comments

Comments
 (0)