
Make import cycles more predictable by prioritizing different import forms #1736


Merged 9 commits on Jun 27, 2016
mypy/build.py (109 changes: 87 additions & 22 deletions)
@@ -327,12 +327,20 @@ def default_lib_path(data_dir: str, pyversion: Tuple[int, int],
('data_json', str), # path of <id>.data.json
('suppressed', List[str]), # dependencies that weren't imported
('flags', Optional[List[str]]), # build flags
('dep_prios', List[int]),
])
# NOTE: dependencies + suppressed == all unreachable imports;
# NOTE: dependencies + suppressed == all reachable imports;
# suppressed contains those reachable imports that were prevented by
# --silent-imports or simply not found.


# Priorities used for imports. (Here, top-level includes inside a class.)
Collaborator: Maybe mention what priorities are used for and why.

PRI_HIGH = 5 # top-level "from X import blah"
PRI_MED = 10 # top-level "import X"
PRI_LOW = 20 # either form inside a function
PRI_ALL = 99 # include all priorities
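
For illustration, here is a hypothetical module annotated with the priority each import form would receive (module names are placeholders; per the note above, an import in a class body counts as top-level):

from os import path   # top-level "from X import ..."   -> PRI_HIGH (5)
import sys            # top-level "import X"            -> PRI_MED (10)

class C:
    import re         # class body counts as top-level  -> PRI_MED (10)

def helper() -> None:
    import json       # either form inside a function   -> PRI_LOW (20)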


class BuildManager:
"""This class holds shared state for building a mypy program.

@@ -393,12 +401,13 @@ def __init__(self, data_dir: str,
self.missing_modules = set() # type: Set[str]

def all_imported_modules_in_file(self,
file: MypyFile) -> List[Tuple[str, int]]:
file: MypyFile) -> List[Tuple[int, str, int]]:
"""Find all reachable import statements in a file.

Return list of tuples (module id, import line number) for all modules
imported in file.
Return list of tuples (priority, module id, import line number)
for all modules imported in file; lower numbers == higher priority.
"""

def correct_rel_imp(imp: Union[ImportFrom, ImportAll]) -> str:
"""Function to correct for relative imports."""
file_id = file.fullname()
@@ -413,12 +422,13 @@ def correct_rel_imp(imp: Union[ImportFrom, ImportAll]) -> str:

return new_id

res = [] # type: List[Tuple[str, int]]
res = [] # type: List[Tuple[int, str, int]]
Collaborator: What if the same module is present multiple times here with different priorities? Should we choose the highest priority?

Member Author: Yes, and that's what happens later, at the call site (line 1280 below once I've pushed a new version).
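
A minimal sketch of that de-duplication, mirroring the min() merge at the parse_file call site (toy values, not real modules):

priorities = {}  # type: Dict[str, int]
for pri, mod_id, line in [(PRI_MED, 'x', 1), (PRI_HIGH, 'x', 7)]:
    # Keep the numerically smallest value, i.e. the highest priority.
    priorities[mod_id] = min(pri, priorities.get(mod_id, PRI_LOW))
assert priorities['x'] == PRI_HIGH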

for imp in file.imports:
if not imp.is_unreachable:
if isinstance(imp, Import):
pri = PRI_MED if imp.is_top_level else PRI_LOW
for id, _ in imp.ids:
res.append((id, imp.line))
res.append((pri, id, imp.line))
elif isinstance(imp, ImportFrom):
cur_id = correct_rel_imp(imp)
pos = len(res)
@@ -427,7 +437,7 @@ def correct_rel_imp(imp: Union[ImportFrom, ImportAll]) -> str:
for name, __ in imp.names:
sub_id = cur_id + '.' + name
if self.is_module(sub_id):
res.append((sub_id, imp.line))
res.append((0, sub_id, imp.line))
Collaborator: Use symbolic constant for priority here? Why is the priority zero here?

else:
all_are_submodules = False
# If all imported names are submodules, don't add
@@ -436,9 +446,12 @@ def correct_rel_imp(imp: Union[ImportFrom, ImportAll]) -> str:
# cur_id is also a dependency, and we should
# insert it *before* any submodules.
if not all_are_submodules:
res.insert(pos, ((cur_id, imp.line)))
pri = PRI_HIGH if imp.is_top_level else PRI_LOW
res.insert(pos, ((pri, cur_id, imp.line)))
elif isinstance(imp, ImportAll):
res.append((correct_rel_imp(imp), imp.line))
pri = PRI_HIGH if imp.is_top_level else PRI_LOW
res.append((pri, correct_rel_imp(imp), imp.line))

return res
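
As a rough sketch (hypothetical file, illustrative line numbers), the tuples produced would be:

#   import sys               ->  (PRI_MED, 'sys', 1)
#   from pkg import mod      ->  (0, 'pkg.mod', 2)     # mod is itself a module
#   from pkg import name     ->  (PRI_HIGH, 'pkg', 3)  # name is not a module
#   def f(): import json     ->  (PRI_LOW, 'json', 5)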

def is_module(self, id: str) -> bool:
@@ -754,16 +767,18 @@ def find_cache_meta(id: str, path: str, manager: BuildManager) -> Optional[Cache
data_json,
meta.get('suppressed', []),
meta.get('flags'),
meta.get('dep_prios', []),
)
if (m.id != id or m.path != path or
m.mtime is None or m.size is None or
m.dependencies is None or m.data_mtime is None):
return None

# Metadata generated by older mypy version and no flags were saved
if m.flags is None:
# Ignore cache if generated by an older mypy version.
if m.flags is None or len(m.dependencies) != len(m.dep_prios):
return None

# Ignore cache if (relevant) flags aren't the same.
cached_flags = select_flags_affecting_cache(m.flags)
current_flags = select_flags_affecting_cache(manager.flags)
if cached_flags != current_flags:
@@ -802,6 +817,7 @@ def random_string():

def write_cache(id: str, path: str, tree: MypyFile,
dependencies: List[str], suppressed: List[str],
dep_prios: List[int],
manager: BuildManager) -> None:
"""Write cache files for a module.

@@ -811,6 +827,7 @@ def write_cache(id: str, path: str, tree: MypyFile,
tree: the fully checked module data
dependencies: module IDs on which this module depends
suppressed: module IDs which were suppressed as dependencies
dep_prios: priorities (parallel array to dependencies)
manager: the build manager (for pyversion, log/trace)
"""
path = os.path.abspath(path)
@@ -840,6 +857,7 @@ def write_cache(id: str, path: str, tree: MypyFile,
'dependencies': dependencies,
'suppressed': suppressed,
'flags': manager.flags,
'dep_prios': dep_prios,
}
with open(meta_json_tmp, 'w') as f:
json.dump(meta, f, sort_keys=True)
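
Illustratively (hypothetical values), the written meta keeps the two lists parallel, so dependencies[i] has priority dep_prios[i]:

meta = {
    'dependencies': ['builtins', 'sys', 'json'],
    'dep_prios':    [5, 10, 20],
    # ... the other keys shown above ...
}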
@@ -1012,6 +1030,7 @@ class State:
tree = None # type: Optional[MypyFile]
dependencies = None # type: List[str]
suppressed = None # type: List[str] # Suppressed/missing dependencies
priorities = None # type: Dict[str, int]

# Map each dependency to the line number where it is first imported
dep_line_map = None # type: Dict[str, int]
@@ -1114,6 +1133,9 @@ def __init__(self,
# compare them to the originals later.
self.dependencies = list(self.meta.dependencies)
self.suppressed = list(self.meta.suppressed)
assert len(self.meta.dependencies) == len(self.meta.dep_prios)
self.priorities = {id: pri
for id, pri in zip(self.meta.dependencies, self.meta.dep_prios)}
self.dep_line_map = {}
else:
# Parse the file (and then some) to get the dependencies.
@@ -1249,8 +1271,10 @@ def parse_file(self) -> None:
# Also keep track of each dependency's source line.
dependencies = []
suppressed = []
priorities = {} # type: Dict[str, int] # id -> priority
dep_line_map = {} # type: Dict[str, int] # id -> line
for id, line in manager.all_imported_modules_in_file(self.tree):
for pri, id, line in manager.all_imported_modules_in_file(self.tree):
priorities[id] = min(pri, priorities.get(id, PRI_LOW))
if id == self.id:
continue
# Omit missing modules, as otherwise we could not type-check
@@ -1281,6 +1305,7 @@ def parse_file(self) -> None:
# for differences (e.g. --silent-imports).
self.dependencies = dependencies
self.suppressed = suppressed
self.priorities = priorities
self.dep_line_map = dep_line_map
self.check_blockers()

@@ -1320,8 +1345,10 @@ def type_check(self) -> None:

def write_cache(self) -> None:
if self.path and INCREMENTAL in self.manager.flags and not self.manager.errors.is_errors():
dep_prios = [self.priorities.get(dep, PRI_HIGH) for dep in self.dependencies]
write_cache(self.id, self.path, self.tree,
list(self.dependencies), list(self.suppressed),
dep_prios,
self.manager)


@@ -1388,10 +1415,9 @@ def process_graph(graph: Graph, manager: BuildManager) -> None:
# dependencies) to roots (those from which everything else can be
# reached).
for ascc in sccs:
# Sort the SCC's nodes in *reverse* order of encounter.
# This is a heuristic for handling import cycles.
# Order the SCC's nodes using a heuristic.
# Note that ascc is a set, and scc is a list.
scc = sorted(ascc, key=lambda id: -graph[id].order)
scc = order_ascc(graph, ascc)
# If builtins is in the list, move it last. (This is a bit of
# a hack, but it's necessary because the builtins module is
# part of a small cycle involving at least {builtins, abc,
@@ -1400,6 +1426,12 @@ def process_graph(graph: Graph, manager: BuildManager) -> None:
if 'builtins' in ascc:
scc.remove('builtins')
scc.append('builtins')
if manager.flags.count(VERBOSE) >= 2:
for id in scc:
manager.trace("Priorities for %s:" % id,
" ".join("%s:%d" % (x, graph[id].priorities[x])
for x in graph[id].dependencies
if x in ascc and x in graph[id].priorities))
# Because the SCCs are presented in topological sort order, we
# don't need to look at dependencies recursively for staleness
# -- the immediate dependencies are sufficient.
@@ -1426,7 +1458,7 @@ def process_graph(graph: Graph, manager: BuildManager) -> None:
# cache file is newer than any scc node's cache file.
oldest_in_scc = min(graph[id].meta.data_mtime for id in scc)
newest_in_deps = 0 if not deps else max(graph[dep].meta.data_mtime for dep in deps)
if manager.flags.count(VERBOSE) >= 2: # Dump all mtimes for extreme debugging.
if manager.flags.count(VERBOSE) >= 3: # Dump all mtimes for extreme debugging.
all_ids = sorted(ascc | deps, key=lambda id: graph[id].meta.data_mtime)
for id in all_ids:
if id in scc:
@@ -1466,6 +1498,27 @@ def process_graph(graph: Graph, manager: BuildManager) -> None:
process_stale_scc(graph, scc)


def order_ascc(graph: Graph, ascc: AbstractSet[str], pri_max: int = PRI_ALL) -> List[str]:
"""Come up with the ideal processing order within an SCC."""
Collaborator: Explain what this algorithm does in more detail. Describe the time complexity. It would be nice to have unit tests for this, as the code looks a little tricky.

if len(ascc) == 1:
return [s for s in ascc]
pri_spread = set()
for id in ascc:
state = graph[id]
for dep in state.dependencies:
if dep in ascc:
pri = state.priorities.get(dep, PRI_HIGH)
if pri < pri_max:
pri_spread.add(pri)
if len(pri_spread) == 1:
# Filtered dependencies are homogeneous -- order by global order.
return sorted(ascc, key=lambda id: -graph[id].order)
pri_max = max(pri_spread)
sccs = sorted_components(graph, ascc, pri_max)
# The recursion is bounded by the len(pri_spread) check above.
return [s for ss in sccs for s in order_ascc(graph, ss, pri_max)]
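
A self-contained toy illustration of the recursion (plain dicts, not mypy's actual data structures): when priorities are heterogeneous, dropping the lowest-priority edges breaks the cycle, and the resulting sub-SCCs are ordered leaves-first:

PRI_MED, PRI_LOW = 10, 20
deps = {'a': {'b': PRI_MED},   # a has a top-level "import b"
        'b': {'a': PRI_LOW}}   # b imports a only inside a function
pri_spread = {p for mod in deps for p in deps[mod].values()}
assert pri_spread == {PRI_MED, PRI_LOW}   # heterogeneous, so we recurse
pri_max = max(pri_spread)                 # 20
edges = {mod: [d for d, p in deps[mod].items() if p < pri_max] for mod in deps}
assert edges == {'a': ['b'], 'b': []}     # cycle broken: b comes before a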


def process_fresh_scc(graph: Graph, scc: List[str]) -> None:
"""Process the modules in one SCC from their cached data."""
for id in scc:
@@ -1497,7 +1550,9 @@ def process_stale_scc(graph: Graph, scc: List[str]) -> None:
graph[id].write_cache()


def sorted_components(graph: Graph) -> List[AbstractSet[str]]:
def sorted_components(graph: Graph,
vertices: Optional[AbstractSet[str]] = None,
pri_max: int = PRI_ALL) -> List[AbstractSet[str]]:
"""Return the graph's SCCs, topologically sorted by dependencies.

The sort order is from leaves (nodes without dependencies) to
@@ -1507,17 +1562,17 @@ def sorted_components(graph: Graph) -> List[AbstractSet[str]]:
dependencies that aren't present in graph.keys() are ignored.
"""
# Compute SCCs.
vertices = set(graph)
edges = {id: [dep for dep in st.dependencies if dep in graph]
for id, st in graph.items()}
if vertices is None:
vertices = set(graph)
edges = {id: deps_filtered(graph, vertices, id, pri_max) for id in vertices}
sccs = list(strongly_connected_components(vertices, edges))
# Topsort.
sccsmap = {id: frozenset(scc) for scc in sccs for id in scc}
data = {} # type: Dict[AbstractSet[str], Set[AbstractSet[str]]]
for scc in sccs:
deps = set() # type: Set[AbstractSet[str]]
for id in scc:
deps.update(sccsmap[x] for x in graph[id].dependencies if x in graph)
deps.update(sccsmap[x] for x in deps_filtered(graph, vertices, id, pri_max))
data[frozenset(scc)] = deps
res = []
for ready in topsort(data):
@@ -1534,7 +1589,17 @@ def sorted_components(graph: Graph) -> List[AbstractSet[str]]:
return res


def strongly_connected_components(vertices: Set[str],
def deps_filtered(graph: Graph, vertices: AbstractSet[str], id: str, pri_max: int) -> List[str]:
"""Filter dependencies for id with pri < pri_max."""
if id not in vertices:
return []
state = graph[id]
return [dep
for dep in state.dependencies
if dep in vertices and state.priorities.get(dep, PRI_HIGH) < pri_max]


def strongly_connected_components(vertices: AbstractSet[str],
edges: Dict[str, List[str]]) -> Iterator[Set[str]]:
"""Compute Strongly Connected Components of a directed graph.

mypy/nodes.py (1 change: 1 addition & 0 deletions)
@@ -224,6 +224,7 @@ def deserialize(cls, data: JsonDict) -> 'MypyFile':
class ImportBase(Node):
"""Base class for all import statements."""
is_unreachable = False
is_top_level = False # Set by semanal.FirstPass
# If an import replaces existing definitions, we construct dummy assignment
# statements that assign the imported names to the names in the current scope,
# for type checking purposes. Example:
mypy/semanal.py (5 changes: 5 additions & 0 deletions)
@@ -2414,12 +2414,14 @@ def visit_import_from(self, node: ImportFrom) -> None:
# We can't bind module names during the first pass, as the target module might be
# unprocessed. However, we add dummy unbound imported names to the symbol table so
# that we at least know that the name refers to a module.
node.is_top_level = True
for name, as_name in node.names:
imported_name = as_name or name
if imported_name not in self.sem.globals:
self.sem.add_symbol(imported_name, SymbolTableNode(UNBOUND_IMPORTED, None), node)

def visit_import(self, node: Import) -> None:
node.is_top_level = True
# This is similar to visit_import_from -- see the comment there.
for id, as_id in node.ids:
imported_id = as_id or id
@@ -2429,6 +2431,9 @@ def visit_import(self, node: Import) -> None:
# If the previous symbol is a variable, this should take precedence.
self.sem.globals[imported_id] = SymbolTableNode(UNBOUND_IMPORTED, None)

def visit_import_all(self, node: ImportAll) -> None:
node.is_top_level = True

def visit_while_stmt(self, s: WhileStmt) -> None:
s.body.accept(self)
if s.else_body:
test-data/unit/check-modules.test (32 changes: 32 additions & 0 deletions)
@@ -811,3 +811,35 @@ from a import x
[file a/__init__.py]
x = 0
[out]


-- Test stability under import cycles
-- ----------------------------------

-- The two tests are identical except that main has 'import x' in one and 'import y' in the other;
-- both should type-check cleanly regardless of which member of the cycle is processed first.
Collaborator: Suggest adding more motivation for this test. In particular, when would one of these fail.


[case testImportCycleStability1]
import x
[file x.py]
class Base:
attr = 'x'
Collaborator: Maybe use a trickier initializer here so that better type inference is less likely to make the order irrelevant.

def foo():
import y
[file y.py]
import x
class Sub(x.Base):
attr = x.Base.attr
[out]

[case testImportCycleStability2]
import y
[file x.py]
class Base:
attr = 'x'
def foo():
import y
[file y.py]
import x
class Sub(x.Base):
attr = x.Base.attr
[out]