4
4
5
5
from mypy .mro import calculate_mro
6
6
from mypy .nodes import (
7
- AssignmentStmt , ClassDef , Expression , ImportedName , Lvalue , MypyFile , NameExpr , SymbolNode , TypeInfo ,
8
- SymbolTable , SymbolTableNode , Block , GDEF , MDEF , Var )
9
- from mypy .plugin import FunctionContext , MethodContext
7
+ GDEF , MDEF , AssignmentStmt , Block , CallExpr , ClassDef , Expression , ImportedName , Lvalue , MypyFile , NameExpr ,
8
+ SymbolNode , SymbolTable , SymbolTableNode , TypeInfo , Var ,
9
+ )
10
+ from mypy .plugin import CheckerPluginInterface , FunctionContext , MethodContext
10
11
from mypy .types import (
11
- AnyType , Instance , NoneTyp , Type , TypeOfAny , TypeVarType , UnionType ,
12
- TupleType , TypedDictType )
12
+ AnyType , Instance , NoneTyp , TupleType , Type , TypedDictType , TypeOfAny , TypeVarType , UnionType ,
13
+ )
13
14
14
15
if typing .TYPE_CHECKING :
15
16
from mypy .checker import TypeChecker
@@ -216,6 +217,7 @@ def extract_field_setter_type(tp: Instance) -> Optional[Type]:
216
217
217
218
218
219
def extract_field_getter_type (tp : Type ) -> Optional [Type ]:
220
+ """ Extract return type of __get__ of subclass of Field"""
219
221
if not isinstance (tp , Instance ):
220
222
return None
221
223
if tp .type .has_base (FIELD_FULLNAME ):
@@ -226,13 +228,12 @@ def extract_field_getter_type(tp: Type) -> Optional[Type]:
226
228
return None
227
229
228
230
229
- def get_django_metadata (model : TypeInfo ) -> Dict [str , typing .Any ]:
230
- return model .metadata .setdefault ('django' , {})
231
+ def get_django_metadata (model_info : TypeInfo ) -> Dict [str , typing .Any ]:
232
+ return model_info .metadata .setdefault ('django' , {})
231
233
232
234
233
235
def get_related_field_primary_key_names (base_model : TypeInfo ) -> typing .List [str ]:
234
- django_metadata = get_django_metadata (base_model )
235
- return django_metadata .setdefault ('related_field_primary_keys' , [])
236
+ return get_django_metadata (base_model ).setdefault ('related_field_primary_keys' , [])
236
237
237
238
238
239
def get_fields_metadata (model : TypeInfo ) -> Dict [str , typing .Any ]:
@@ -243,6 +244,10 @@ def get_lookups_metadata(model: TypeInfo) -> Dict[str, typing.Any]:
243
244
return get_django_metadata (model ).setdefault ('lookups' , {})
244
245
245
246
247
+ def get_related_managers_metadata (model : TypeInfo ) -> Dict [str , typing .Any ]:
248
+ return get_django_metadata (model ).setdefault ('related_managers' , {})
249
+
250
+
246
251
def extract_explicit_set_type_of_model_primary_key (model : TypeInfo ) -> Optional [Type ]:
247
252
"""
248
253
If field with primary_key=True is set on the model, extract its __set__ type.
@@ -310,7 +315,7 @@ def is_field_nullable(model: TypeInfo, field_name: str) -> bool:
310
315
return get_fields_metadata (model ).get (field_name , {}).get ('null' , False )
311
316
312
317
313
- def is_foreign_key (t : Type ) -> bool :
318
+ def is_foreign_key_like (t : Type ) -> bool :
314
319
if not isinstance (t , Instance ):
315
320
return False
316
321
return has_any_of_bases (t .type , (FOREIGN_KEY_FULLNAME , ONETOONE_FIELD_FULLNAME ))
@@ -366,13 +371,14 @@ def make_named_tuple(api: 'TypeChecker', fields: 'OrderedDict[str, Type]', name:
366
371
return TupleType (list (fields .values ()), fallback = fallback )
367
372
368
373
369
- def make_typeddict (api : 'TypeChecker' , fields : 'OrderedDict[str, Type]' , required_keys : typing .Set [str ]) -> Type :
374
+ def make_typeddict (api : CheckerPluginInterface , fields : 'OrderedDict[str, Type]' ,
375
+ required_keys : typing .Set [str ]) -> TypedDictType :
370
376
object_type = api .named_generic_type ('mypy_extensions._TypedDict' , [])
371
377
typed_dict_type = TypedDictType (fields , required_keys = required_keys , fallback = object_type )
372
378
return typed_dict_type
373
379
374
380
375
- def make_tuple (api : 'TypeChecker' , fields : typing .List [Type ]) -> Type :
381
+ def make_tuple (api : 'TypeChecker' , fields : typing .List [Type ]) -> TupleType :
376
382
implicit_any = AnyType (TypeOfAny .special_form )
377
383
fallback = api .named_generic_type ('builtins.tuple' , [implicit_any ])
378
384
return TupleType (fields , fallback = fallback )
@@ -386,3 +392,52 @@ def get_private_descriptor_type(type_info: TypeInfo, private_field_name: str, is
386
392
descriptor_type = make_optional (descriptor_type )
387
393
return descriptor_type
388
394
return AnyType (TypeOfAny .unannotated )
395
+
396
+
397
+ def iter_over_classdefs (module_file : MypyFile ) -> typing .Iterator [ClassDef ]:
398
+ for defn in module_file .defs :
399
+ if isinstance (defn , ClassDef ):
400
+ yield defn
401
+
402
+
403
+ def iter_call_assignments (klass : ClassDef ) -> typing .Iterator [typing .Tuple [Lvalue , CallExpr ]]:
404
+ for lvalue , rvalue in iter_over_assignments (klass ):
405
+ if isinstance (rvalue , CallExpr ):
406
+ yield lvalue , rvalue
407
+
408
+
409
+ def get_related_manager_type_from_metadata (model_info : TypeInfo , related_manager_name : str ,
410
+ api : CheckerPluginInterface ) -> Optional [Instance ]:
411
+ related_manager_metadata = get_related_managers_metadata (model_info )
412
+ if not related_manager_metadata :
413
+ return None
414
+
415
+ if related_manager_name not in related_manager_metadata :
416
+ return None
417
+
418
+ manager_class_name = related_manager_metadata [related_manager_name ]['manager' ]
419
+ of = related_manager_metadata [related_manager_name ]['of' ]
420
+ of_types = []
421
+ for of_type_name in of :
422
+ if of_type_name == 'any' :
423
+ of_types .append (AnyType (TypeOfAny .implementation_artifact ))
424
+ else :
425
+ try :
426
+ of_type = api .named_generic_type (of_type_name , [])
427
+ except AssertionError :
428
+ # Internal error: attempted lookup of unknown name
429
+ of_type = AnyType (TypeOfAny .implementation_artifact )
430
+
431
+ of_types .append (of_type )
432
+
433
+ return api .named_generic_type (manager_class_name , of_types )
434
+
435
+
436
+ def get_primary_key_field_name (model_info : TypeInfo ) -> Optional [str ]:
437
+ for base in model_info .mro :
438
+ fields = get_fields_metadata (base )
439
+ for field_name , field_props in fields .items ():
440
+ is_primary_key = field_props .get ('primary_key' , False )
441
+ if is_primary_key :
442
+ return field_name
443
+ return None
0 commit comments