11"""Collect type information."""
22
3- import builtins
43import importlib
54import json
65import logging
@@ -226,27 +225,6 @@ def _is_type(value):
226225 return is_type
227226
228227
229- def _builtin_types ():
230- """Return known imports for all builtins (in the current runtime).
231-
232- Returns
233- -------
234- known_imports : dict[str, KnownImport]
235- """
236- known_builtins = set (dir (builtins ))
237-
238- known_imports = {}
239- for name in known_builtins :
240- if name .startswith ("_" ):
241- continue
242- value = getattr (builtins , name )
243- if not _is_type (value ):
244- continue
245- known_imports [name ] = KnownImport (builtin_name = name )
246-
247- return known_imports
248-
249-
250228def _runtime_types_in_module (module_name ):
251229 module = importlib .import_module (module_name )
252230 types = {}
@@ -277,18 +255,20 @@ def common_known_types():
277255 Examples
278256 --------
279257 >>> types = common_known_types()
280- >>> types["str"]
281- <KnownImport str (builtin) >
282- >>> types["Iterable"]
283- <KnownImport 'from collections.abc import Iterable'>
258+ >>> types["builtins. str"]
259+ <KnownImport 'from builtins import str' >
260+ >>> types["typing. Iterable"]
261+ <KnownImport 'from typing import Iterable'>
284262 >>> types["collections.abc.Iterable"]
285263 <KnownImport 'from collections.abc import Iterable'>
286264 """
287- known_imports = _builtin_types ()
288- known_imports |= _runtime_types_in_module ("typing" )
289- # Overrides containers from typing
290- known_imports |= _runtime_types_in_module ("collections.abc" )
291- return known_imports
265+ from ._stdlib_types import stdlib_types
266+
267+ types = {
268+ f"{ module } .{ type_name } " : KnownImport (import_path = module , import_name = type_name )
269+ for module , type_name in stdlib_types
270+ }
271+ return types
292272
293273
294274class TypeCollector (cst .CSTVisitor ):
@@ -334,7 +314,7 @@ def collect(cls, file):
334314
335315 Returns
336316 -------
337- collected : dict[str, KnownImport]
317+ collected_types : dict[str, KnownImport]
338318 """
339319 file = Path (file )
340320 with file .open ("r" ) as fo :
@@ -343,7 +323,7 @@ def collect(cls, file):
343323 tree = cst .parse_module (source )
344324 collector = cls (module_name = module_name_from_path (file ))
345325 tree .visit (collector )
346- return collector .known_imports
326+ return collector .collected_types
347327
348328 def __init__ (self , * , module_name ):
349329 """Initialize type collector.
@@ -354,7 +334,7 @@ def __init__(self, *, module_name):
354334 """
355335 self .module_name = module_name
356336 self ._stack = []
357- self .known_imports = {}
337+ self .collected_types = {}
358338
359339 def visit_ClassDef (self , node : cst .ClassDef ) -> bool :
360340 self ._stack .append (node .name .value )
@@ -396,9 +376,104 @@ def _collect_type_annotation(self, stack):
396376 stack : Iterable[str]
397377 A list of names that form the path to the collected type.
398378 """
399- qualname = "." .join ([self .module_name , * stack ])
400379 known_import = KnownImport (import_path = self .module_name , import_name = stack [0 ])
401- self .known_imports [qualname ] = known_import
380+
381+ qualname = f"{ self .module_name } .{ '.' .join (stack )} "
382+ scoped_name = f"{ self .module_name } :{ '.' .join (stack )} "
383+ self .collected_types [qualname ] = known_import
384+ self .collected_types [scoped_name ] = known_import
385+
386+
387+ class StubTypeCollector (TypeCollector ):
388+
389+ def __init__ (self , * , module_name ):
390+ """Initialize type collector.
391+
392+ Parameters
393+ ----------
394+ module_name : str
395+ """
396+ super ().__init__ (module_name = module_name )
397+ self .collected_types = set ()
398+ self .dunder_all = set ()
399+
400+ @classmethod
401+ def collect (cls , file ):
402+ """Collect importable type annotations in given file.
403+
404+ Parameters
405+ ----------
406+ file : Path
407+
408+ Returns
409+ -------
410+ collected_types : dict[str, KnownImport]
411+ """
412+ file = Path (file )
413+ with file .open ("r" ) as fo :
414+ source = fo .read ()
415+
416+ tree = cst .parse_module (source )
417+ collector = cls (module_name = module_name_from_path (file ))
418+ tree .visit (collector )
419+ return collector .collected_types , collector .dunder_all
420+
421+ def visit_ImportFrom (self , node ):
422+ # https://typing.python.org/en/latest/spec/distributing.html#import-conventions
423+
424+ if cstm .matches (node , cstm .ImportFrom (names = cstm .ImportStar ())):
425+ module_names = cstm .findall (node .module , cstm .Name ())
426+ module = "_" .join (name .value for name in module_names )
427+ stack = [* self ._stack , f"<Reference: { module } .*>" ]
428+ self ._collect_type_annotation (stack )
429+
430+ names = cstm .findall (node , cstm .AsName ())
431+ for name in names :
432+ if cstm .matches (name , cstm .AsName (name = cstm .Name ())):
433+ value = name .name .value
434+ assert value
435+ if value == "__all__" :
436+ continue
437+
438+ stack = [* self ._stack , value ]
439+ self ._collect_type_annotation (stack )
440+
441+ def visit_AugAssign (self , node ):
442+ is_add_assign_to_dunder_all = cstm .matches (
443+ node ,
444+ cstm .AugAssign (
445+ target = cstm .Name (value = "__all__" ), operator = cstm .AddAssign ()
446+ ),
447+ )
448+ is_assign_list = cstm .matches (node .value , cstm .List ())
449+ if is_add_assign_to_dunder_all and is_assign_list :
450+ strings = cstm .findall (node .value , cstm .SimpleString ())
451+ for string in strings :
452+ self ._collect_dunder_all (string .value )
453+
454+ def visit_Assign (self , node ):
455+ is_assign_to_dunder_all = cstm .matches (
456+ node ,
457+ cstm .Assign (targets = [cstm .AssignTarget (target = cstm .Name (value = "__all__" ))]),
458+ )
459+ is_assign_list = cstm .matches (node .value , cstm .List ())
460+ if is_assign_to_dunder_all and is_assign_list :
461+ strings = cstm .findall (node .value , cstm .SimpleString ())
462+ for string in strings :
463+ self ._collect_dunder_all (string .value )
464+
465+ def _collect_type_annotation (self , stack ):
466+ """Collect an importable type annotation.
467+
468+ Parameters
469+ ----------
470+ stack : Iterable[str]
471+ A list of names that form the path to the collected type.
472+ """
473+ self .collected_types .add ((self .module_name , "." .join (stack )))
474+
475+ def _collect_dunder_all (self , value ):
476+ self .dunder_all .add ((self .module_name , value .strip ("'\" " )))
402477
403478
404479class TypeMatcher :
@@ -427,6 +502,7 @@ def __init__(
427502 types = None ,
428503 type_prefixes = None ,
429504 type_nicknames = None ,
505+ implicit_modules = ("collections.abc" , "typing" , "_typeshed" ),
430506 ):
431507 """
432508 Parameters
@@ -438,6 +514,7 @@ def __init__(
438514 self .types = types or common_known_types ()
439515 self .type_prefixes = type_prefixes or {}
440516 self .type_nicknames = type_nicknames or {}
517+ self .implicit_modules = implicit_modules
441518 self .successful_queries = 0
442519 self .unknown_qualnames = []
443520
@@ -492,20 +569,39 @@ def match(self, search_name):
492569 # Replace alias
493570 search_name = self .type_nicknames .get (search_name , search_name )
494571
495- if type_origin is None and self .current_module :
496- # Try scope of current module
497- module_name = module_name_from_path (self .current_module )
498- try_qualname = f"{ module_name } .{ search_name } "
572+ if type_origin is None :
573+ # Try builtin
574+ try_qualname = f"builtins.{ search_name } "
499575 type_origin = self .types .get (try_qualname )
500576 if type_origin :
501577 type_name = search_name
502578
503579 if type_origin is None and search_name in self .types :
580+ # Direct match
504581 type_name = search_name
505582 type_origin = self .types [search_name ]
506583
584+ if type_origin is None and self .current_module :
585+ # Try scope of current module
586+ for sep in ["." , ":" ]:
587+ try_qualname = f"{ self .current_module } { sep } { search_name } "
588+ type_origin = self .types .get (try_qualname )
589+ if type_origin :
590+ type_name = search_name
591+ break
592+
593+ if type_origin is None and self .implicit_modules :
594+ # Try implicit modules
595+ for module in self .implicit_modules :
596+ try_qualname = f"{ module } .{ search_name } "
597+ type_origin = self .types .get (try_qualname )
598+ if type_origin :
599+ type_name = search_name
600+ break
601+
507602 if type_origin is None :
508- # Try a subset of the qualname (first 'a.b.c', then 'a.b' and 'a')
603+ # Try matching with module prefix,
604+ # try a subset of the qualname (first 'a.b.c', then 'a.b' and 'a')
509605 for partial_qualname in reversed (accumulate_qualname (search_name )):
510606 type_origin = self .type_prefixes .get (partial_qualname )
511607 if type_origin :
0 commit comments