Skip to content

Commit f6457d1

Browse files
Add get_models_foreign_key helper
1 parent e933c5e commit f6457d1

File tree

2 files changed

+35
-32
lines changed

2 files changed

+35
-32
lines changed

mypy_django_plugin/django/context.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,11 @@ def get_model_fields(self, model_cls: Type[Model]) -> Iterator["Field[Any, Any]"
118118
if isinstance(field, Field):
119119
yield field
120120

121+
def get_model_foreign_keys(self, model_cls: Type[Model]) -> Iterator["ForeignKey[Any, Any]"]:
122+
for field in model_cls._meta.get_fields():
123+
if isinstance(field, ForeignKey):
124+
yield field
125+
121126
def get_model_relations(self, model_cls: Type[Model]) -> Iterator[ForeignObjectRel]:
122127
for field in model_cls._meta.get_fields():
123128
if isinstance(field, ForeignObjectRel):

mypy_django_plugin/transformers/models.py

Lines changed: 30 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from django.db.models import Manager, Model
44
from django.db.models.fields import DateField, DateTimeField, Field
5-
from django.db.models.fields.related import ForeignKey
65
from django.db.models.fields.reverse_related import ForeignObjectRel, OneToOneRel
76
from mypy.checker import TypeChecker
87
from mypy.nodes import ARG_STAR2, Argument, AssignmentStmt, CallExpr, Context, NameExpr, TypeInfo, Var
@@ -234,41 +233,40 @@ def run_with_model_cls(self, model_cls: Type[Model]) -> None:
234233

235234
class AddRelatedModelsId(ModelClassInitializer):
236235
def run_with_model_cls(self, model_cls: Type[Model]) -> None:
237-
for field in model_cls._meta.get_fields():
238-
if isinstance(field, ForeignKey):
239-
related_model_cls = self.django_context.get_field_related_model_cls(field)
240-
if related_model_cls is None:
241-
error_context: Context = self.ctx.cls
242-
field_sym = self.ctx.cls.info.get(field.name)
243-
if field_sym is not None and field_sym.node is not None:
244-
error_context = field_sym.node
245-
self.api.fail(
246-
f"Cannot find model {field.related_model!r} referenced in field {field.name!r}",
247-
ctx=error_context,
248-
)
249-
self.add_new_node_to_model_class(field.attname, AnyType(TypeOfAny.explicit))
250-
continue
236+
for field in self.django_context.get_model_foreign_keys(model_cls):
237+
related_model_cls = self.django_context.get_field_related_model_cls(field)
238+
if related_model_cls is None:
239+
error_context: Context = self.ctx.cls
240+
field_sym = self.ctx.cls.info.get(field.name)
241+
if field_sym is not None and field_sym.node is not None:
242+
error_context = field_sym.node
243+
self.api.fail(
244+
f"Cannot find model {field.related_model!r} referenced in field {field.name!r}",
245+
ctx=error_context,
246+
)
247+
self.add_new_node_to_model_class(field.attname, AnyType(TypeOfAny.explicit))
248+
continue
251249

252-
if related_model_cls._meta.abstract:
253-
continue
250+
if related_model_cls._meta.abstract:
251+
continue
254252

255-
rel_target_field = self.django_context.get_related_target_field(related_model_cls, field)
256-
if not rel_target_field:
257-
continue
253+
rel_target_field = self.django_context.get_related_target_field(related_model_cls, field)
254+
if not rel_target_field:
255+
continue
258256

259-
try:
260-
field_info = self.lookup_class_typeinfo_or_incomplete_defn_error(rel_target_field.__class__)
261-
except helpers.IncompleteDefnException as exc:
262-
if not self.api.final_iteration:
263-
raise exc
264-
else:
265-
continue
257+
try:
258+
field_info = self.lookup_class_typeinfo_or_incomplete_defn_error(rel_target_field.__class__)
259+
except helpers.IncompleteDefnException as exc:
260+
if not self.api.final_iteration:
261+
raise exc
262+
else:
263+
continue
266264

267-
is_nullable = self.django_context.get_field_nullability(field, None)
268-
set_type, get_type = get_field_descriptor_types(
269-
field_info, is_set_nullable=is_nullable, is_get_nullable=is_nullable
270-
)
271-
self.add_new_node_to_model_class(field.attname, Instance(field_info, [set_type, get_type]))
265+
is_nullable = self.django_context.get_field_nullability(field, None)
266+
set_type, get_type = get_field_descriptor_types(
267+
field_info, is_set_nullable=is_nullable, is_get_nullable=is_nullable
268+
)
269+
self.add_new_node_to_model_class(field.attname, Instance(field_info, [set_type, get_type]))
272270

273271

274272
class AddManagers(ModelClassInitializer):

0 commit comments

Comments
 (0)