Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
1b3ff5e
rename infer_symbol_type to infer_symbol_public_type
carljm May 31, 2024
ffd589f
infer public type from all definitions
carljm May 31, 2024
b01cbe3
add TODO comment
carljm Jun 1, 2024
18a64d8
add TODO comment about union-of-all-defs
carljm Jun 1, 2024
a72293e
[red-knot] use reachable definitions in infer_expression_type
carljm Jun 1, 2024
55d7428
[red-knot] extract helper functions in inference tests
carljm Jun 1, 2024
ad223c7
remove need for db = case.db in each test
carljm Jun 1, 2024
76c566f
also extract write_to_path
carljm Jun 1, 2024
89452b9
use textwrap::dedent
carljm Jun 1, 2024
9f494ee
[red-knot] add if-statement support to FlowGraph
carljm Jun 1, 2024
07ab4cd
improve tests
carljm Jun 1, 2024
ebb67bf
remove redundant Branch node
carljm Jun 1, 2024
ac934bb
FlowGraph doesn't need to be pub(crate)
carljm Jun 1, 2024
b1f6153
Merge branch 'main' into cjm/cfg1
carljm Jun 3, 2024
b8535b7
review comments
carljm Jun 3, 2024
3dbf5d7
Merge branch 'cjm/cfg1' into cjm/cfg2
carljm Jun 3, 2024
d73db36
review comments
carljm Jun 3, 2024
6a3780e
Merge branch 'cjm/cfg2' into cjm/cfg3
carljm Jun 3, 2024
ec113a1
review comments
carljm Jun 3, 2024
3d52dfa
Merge branch 'main' into cjm/cfg2
carljm Jun 3, 2024
fda62af
Merge branch 'cjm/cfg2' into cjm/cfg3
carljm Jun 3, 2024
fe878d2
Merge branch 'cjm/cfg3' into cjm/cfg4
carljm Jun 3, 2024
f9312db
Merge branch 'main' into cjm/cfg4
carljm Jun 3, 2024
1d3048d
Merge branch 'main' into cjm/cfg4
carljm Jun 4, 2024
21c430d
only two preds per phi, source order visit
carljm Jun 4, 2024
d42fff9
code review comments
carljm Jun 4, 2024
a85871d
Merge branch 'main' into cjm/cfg4
carljm Jun 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 153 additions & 17 deletions crates/red_knot/src/symbols.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ pub(crate) enum Definition {
FunctionDef(TypedNodeKey<ast::StmtFunctionDef>),
Assignment(TypedNodeKey<ast::StmtAssign>),
AnnotatedAssignment(TypedNodeKey<ast::StmtAnnAssign>),
None,
// TODO with statements, except handlers, function args...
}

Expand Down Expand Up @@ -288,8 +289,8 @@ impl SymbolTable {
let flow_node_id = self.flow_graph.ast_to_flow[&node_key];
ReachableDefinitionsIterator {
table: self,
flow_node_id,
symbol_id,
pending: vec![flow_node_id],
}
}

Expand Down Expand Up @@ -545,25 +546,30 @@ where
#[derive(Debug)]
pub(crate) struct ReachableDefinitionsIterator<'a> {
table: &'a SymbolTable,
flow_node_id: FlowNodeId,
symbol_id: SymbolId,
pending: Vec<FlowNodeId>,
}

impl<'a> Iterator for ReachableDefinitionsIterator<'a> {
type Item = Definition;

fn next(&mut self) -> Option<Self::Item> {
loop {
match &self.table.flow_graph.flow_nodes_by_id[self.flow_node_id] {
FlowNode::Start => return None,
let flow_node_id = self.pending.pop()?;
match &self.table.flow_graph.flow_nodes_by_id[flow_node_id] {
FlowNode::Start => return Some(Definition::None),
FlowNode::Definition(def_node) => {
if def_node.symbol_id == self.symbol_id {
// we found a definition; previous definitions along this path are not
// reachable
self.flow_node_id = FlowGraph::start();
return Some(def_node.definition.clone());
}
self.flow_node_id = def_node.predecessor;
self.pending.push(def_node.predecessor);
}
FlowNode::Branch(branch_node) => {
self.pending.push(branch_node.predecessor);
}
FlowNode::Phi(phi_node) => {
self.pending.push(phi_node.first_predecessor);
self.pending.push(phi_node.second_predecessor);
}
}
}
Expand All @@ -579,15 +585,31 @@ struct FlowNodeId;
/// A node in the control-flow graph that is walked backwards (via predecessor
/// links) to find the definitions of a symbol reachable from a given point.
enum FlowNode {
    /// The entry of the graph; it has no predecessor. Reaching it while walking
    /// backwards means no definition was found along that path.
    Start,
    /// A point in control flow where a symbol is defined.
    Definition(DefinitionFlowNode),
    /// A branch in control flow (created for `if` / `elif` tests).
    Branch(BranchFlowNode),
    /// A join point where two control-flow paths come together.
    Phi(PhiFlowNode),
}

/// A Definition node represents a point in control flow where a symbol is defined
#[derive(Debug)]
struct DefinitionFlowNode {
    /// The symbol being defined at this node.
    symbol_id: SymbolId,
    /// The definition that this node introduces for `symbol_id`.
    definition: Definition,
    /// The flow node immediately preceding this one; the backwards walk
    /// continues here when this node defines a different symbol.
    predecessor: FlowNodeId,
}

/// A Branch node represents a branch in control flow
#[derive(Debug)]
struct BranchFlowNode {
    /// The flow node this branch was reached from: the node current at the
    /// `if` test, or the previous branch node for an `elif`.
    predecessor: FlowNodeId,
}

/// A Phi node represents a join point where control flow paths come together
#[derive(Debug)]
struct PhiFlowNode {
    /// One of the two incoming control-flow paths being joined.
    first_predecessor: FlowNodeId,
    /// The other incoming control-flow path being joined.
    second_predecessor: FlowNodeId,
}

#[derive(Debug, Default)]
struct FlowGraph {
flow_nodes_by_id: IndexVec<FlowNodeId, FlowNode>,
Expand Down Expand Up @@ -636,6 +658,10 @@ impl SymbolTableBuilder {
.add_or_update_symbol(self.cur_scope(), identifier, flags)
}

/// Appends `node` to the flow graph's node storage and returns the id that
/// the storage assigned to it.
fn new_flow_node(&mut self, node: FlowNode) -> FlowNodeId {
    let nodes = &mut self.table.flow_graph.flow_nodes_by_id;
    nodes.push(node)
}

fn add_or_update_symbol_with_def(
&mut self,
identifier: &str,
Expand All @@ -647,15 +673,11 @@ impl SymbolTableBuilder {
.entry(symbol_id)
.or_default()
.push(definition.clone());
let new_flow_node_id = self
.table
.flow_graph
.flow_nodes_by_id
.push(FlowNode::Definition(DefinitionFlowNode {
definition,
symbol_id,
predecessor: self.current_flow_node(),
}));
let new_flow_node_id = self.new_flow_node(FlowNode::Definition(DefinitionFlowNode {
definition,
symbol_id,
predecessor: self.current_flow_node(),
}));
self.set_current_flow_node(new_flow_node_id);
symbol_id
}
Expand Down Expand Up @@ -871,13 +893,127 @@ impl PreorderVisitor<'_> for SymbolTableBuilder {
ast::visitor::preorder::walk_stmt(self, stmt);
self.current_definition = None;
}
ast::Stmt::If(node) => {
// we visit the if "test" condition first regardless
self.visit_expr(&node.test);

// create branch node: does the if test pass or not?
let if_branch = self.new_flow_node(FlowNode::Branch(BranchFlowNode {
predecessor: self.current_flow_node(),
}));

// visit the body of the `if` clause
self.set_current_flow_node(if_branch);
self.visit_body(&node.body);

// Flow node for the last if/elif condition branch; represents the "no branch
// taken yet" possibility (where "taking a branch" means that the condition in an
// if or elif evaluated to true and control flow went into that clause).
let mut prior_branch = if_branch;

// Flow node for the state after the prior if/elif/else clause; represents "we have
// taken one of the branches up to this point." Initially set to the post-if-clause
// state, later will be set to the phi node joining that possible path with the
// possibility that we took a later if/elif/else clause instead.
let mut post_prior_clause = self.current_flow_node();

// Flag to mark if the final clause is an "else" -- if so, that means the "match no
// clauses" path is not possible, we have to go through one of the clauses.
let mut last_branch_is_else = false;

for clause in &node.elif_else_clauses {
if clause.test.is_some() {
// This is an elif clause. Create a new branch node. Its predecessor is the
// previous branch node, because we can only take one branch in an entire
// if/elif/else chain, so if we take this branch, it can only be because we
// didn't take the previous one.
prior_branch = self.new_flow_node(FlowNode::Branch(BranchFlowNode {
predecessor: prior_branch,
}));
self.set_current_flow_node(prior_branch);
} else {
// This is an else clause. No need to create a branch node; there's no
// branch here, if we haven't taken any previous branch, we definitely go
// into the "else" clause.
self.set_current_flow_node(prior_branch);
last_branch_is_else = true;
}
self.visit_elif_else_clause(clause);
// Update `post_prior_clause` to a new phi node joining the possibility that we
// took any of the previous branches with the possibility that we took the one
// just visited.
post_prior_clause = self.new_flow_node(FlowNode::Phi(PhiFlowNode {
first_predecessor: self.current_flow_node(),
second_predecessor: post_prior_clause,
}));
}

if !last_branch_is_else {
// Final branch was not an "else", which means it's possible we took zero
// branches in the entire if/elif chain, so we need one more phi node to join
// the "no branches taken" possibility.
post_prior_clause = self.new_flow_node(FlowNode::Phi(PhiFlowNode {
first_predecessor: post_prior_clause,
second_predecessor: prior_branch,
}));
}

// Onward, with current flow node set to our final Phi node.
self.set_current_flow_node(post_prior_clause);
}
_ => {
ast::visitor::preorder::walk_stmt(self, stmt);
}
}
}
}

/// Textual rendering of the flow graph as a `flowchart TD` diagram (the output
/// looks like Mermaid flowchart syntax): one line per node, followed by one
/// `predecessor-->node` edge line for each predecessor of that node.
impl std::fmt::Display for FlowGraph {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        /// Writes a single edge line from node `from` to node `to`.
        fn edge(f: &mut std::fmt::Formatter, from: u32, to: u32) -> std::fmt::Result {
            writeln!(f, r" id{}-->id{}", from, to)
        }

        writeln!(f, "flowchart TD")?;
        for (id, node) in self.flow_nodes_by_id.iter_enumerated() {
            let this = id.as_u32();
            // Node declaration: the id first, then a shape/label that depends on the kind.
            write!(f, " id{}", this)?;
            match node {
                FlowNode::Start => writeln!(f, r"[\Start/]")?,
                FlowNode::Definition(def_node) => {
                    writeln!(f, r"(Define symbol {})", def_node.symbol_id.as_u32())?;
                    edge(f, def_node.predecessor.as_u32(), this)?;
                }
                FlowNode::Branch(branch_node) => {
                    writeln!(f, r"{{Branch}}")?;
                    edge(f, branch_node.predecessor.as_u32(), this)?;
                }
                FlowNode::Phi(phi_node) => {
                    writeln!(f, r"((Phi))")?;
                    // The second predecessor's edge is emitted before the first's.
                    edge(f, phi_node.second_predecessor.as_u32(), this)?;
                    edge(f, phi_node.first_predecessor.as_u32(), this)?;
                }
            }
        }
        Ok(())
    }
}

/// Newtype around a `KeyValueCache` that maps a `FileId` to a shared
/// (`Arc`) `SymbolTable`.
#[derive(Debug, Default)]
pub struct SymbolTablesStorage(KeyValueCache<FileId, Arc<SymbolTable>>);

Expand Down
110 changes: 87 additions & 23 deletions crates/red_knot/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ pub fn infer_definition_type(
let file_id = symbol.file_id;

match definition {
Definition::None => Ok(Type::Unbound),
Definition::Import(ImportDefinition {
module: module_name,
}) => {
Expand Down Expand Up @@ -223,7 +224,7 @@ mod tests {
use crate::module::{
resolve_module, set_module_search_paths, ModuleName, ModuleSearchPath, ModuleSearchPathKind,
};
use crate::symbols::{resolve_global_symbol, symbol_table, GlobalSymbolId};
use crate::symbols::resolve_global_symbol;
use crate::types::{infer_symbol_public_type, Type};
use crate::Name;

Expand Down Expand Up @@ -399,30 +400,93 @@ mod tests {
#[test]
fn resolve_visible_def() -> anyhow::Result<()> {
let case = create_test()?;
let db = &case.db;

let path = case.src.path().join("a.py");
std::fs::write(path, "y = 1; y = 2; x = y")?;
let file = resolve_module(db, ModuleName::new("a"))?
.expect("module should be found")
.path(db)?
.file();
let symbols = symbol_table(db, file)?;
let x_sym = symbols
.root_symbol_id_by_name("x")
.expect("x symbol should be found");

let ty = infer_symbol_public_type(
db,
GlobalSymbolId {
file_id: file,
symbol_id: x_sym,
},
write_to_path(&case, "a.py", "y = 1; y = 2; x = y")?;

assert_public_type(&case, "a", "x", "Literal[2]")
}

#[test]
fn join_paths() -> anyhow::Result<()> {
let case = create_test()?;

write_to_path(
&case,
"a.py",
"
y = 1
y = 2
if flag:
y = 3
x = y
",
)?;

let jar = HasJar::<SemanticJar>::jar(db)?;
assert!(matches!(ty, Type::IntLiteral(_)));
assert_eq!(format!("{}", ty.display(&jar.type_store)), "Literal[2]");
Ok(())
assert_public_type(&case, "a", "x", "(Literal[2] | Literal[3])")
}

/// A symbol assigned only inside an `if` body may be unbound afterwards, so a
/// later read must see `Unbound` unioned with the assigned type.
#[test]
fn maybe_unbound() -> anyhow::Result<()> {
    let case = create_test()?;

    // The Python fixture needs its own block indentation: `y = 1` must be
    // nested under `if flag:` (otherwise the source is not valid Python and
    // `y` would not be conditionally bound), while `x = y` stays at top level.
    write_to_path(
        &case,
        "a.py",
        "
if flag:
    y = 1
x = y
",
    )?;

    assert_public_type(&case, "a", "x", "(Unbound | Literal[1])")
}

/// With a trailing `else`, exactly one clause of the chain runs, so the
/// definitions before the `if` are not visible after it; reads inside a clause
/// see only what is reachable along that clause's path.
#[test]
fn if_elif_else() -> anyhow::Result<()> {
    let case = create_test()?;

    // The Python fixture needs its own block indentation. `r = y`, `y = 5` and
    // `s = y` all belong to the `else` body: `r` reads the pre-`if` value (2)
    // and `s` reads the value assigned just above it (5), as asserted below.
    write_to_path(
        &case,
        "a.py",
        "
y = 1
y = 2
if flag:
    y = 3
elif flag2:
    y = 4
else:
    r = y
    y = 5
    s = y
x = y
",
    )?;

    assert_public_type(&case, "a", "x", "(Literal[3] | Literal[4] | Literal[5])")?;
    assert_public_type(&case, "a", "r", "Literal[2]")?;
    assert_public_type(&case, "a", "s", "Literal[5]")
}

/// Without a trailing `else`, it is possible that no clause runs, so the
/// definition preceding the `if` stays reachable alongside each clause's own.
#[test]
fn if_elif() -> anyhow::Result<()> {
    let case = create_test()?;

    // The Python fixture needs its own block indentation: each assignment to
    // `y` after a test is nested in that clause, and `x = y` is at top level.
    write_to_path(
        &case,
        "a.py",
        "
y = 1
y = 2
if flag:
    y = 3
elif flag2:
    y = 4
x = y
",
    )?;

    assert_public_type(&case, "a", "x", "(Literal[2] | Literal[3] | Literal[4])")
}
}