Skip to content

Rust: Type inference for for loops and array expressions #19754

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
Jun 24, 2025
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
840ef5c
Rust: Add test cases for type inference in loops.
geoffw0 Jun 12, 2025
f76b562
Rust: Implement type inference for 'for' loops on arrays.
geoffw0 Jun 12, 2025
51343a5
Rust: Implement type inference for ArrayListExprs.
geoffw0 Jun 13, 2025
b89d6d3
Rust: Implement type inference for ArrayRepeatExprs.
geoffw0 Jun 13, 2025
62e3cc5
Merge branch 'main' into typeinfer
geoffw0 Jun 13, 2025
6194676
Rust: Accept consistency failures (for now).
geoffw0 Jun 13, 2025
69da4e7
Rust: Move inferArrayExprType logic into typeEquality predicate.
geoffw0 Jun 17, 2025
66d6770
Rust: If we're inferring both ways, it should really be to any element.
geoffw0 Jun 17, 2025
4292b03
Rust: Add logic for Vecs and slices.
geoffw0 Jun 17, 2025
dec0deb
Rust: Add some more test cases for type inference on Vecs.
geoffw0 Jun 17, 2025
639f85a
Merge branch 'main' into typeinfer
geoffw0 Jun 19, 2025
1622d08
Rust: Add inferArrayExprType.
geoffw0 Jun 19, 2025
f670fcb
Rust: Add a Vec test case that we actually get (explicit type).
geoffw0 Jun 19, 2025
7170e97
Rust: Update test expectations format (type=...).
geoffw0 Jun 19, 2025
d55e8b7
Rust: Add another test case for ranges.
geoffw0 Jun 19, 2025
26e7b2d
Rust: Accept path resolution consistency changes.
geoffw0 Jun 19, 2025
7a25596
Merge branch 'main' into typeinfer
geoffw0 Jun 19, 2025
bfaabab
Rust: Update more expectations.
geoffw0 Jun 23, 2025
34cd976
Rust: Run rustfmt --edition 2024 on the test.
geoffw0 Jun 23, 2025
d02a728
Update rust/ql/lib/codeql/rust/internal/TypeInference.qll
geoffw0 Jun 23, 2025
8c848ac
Rust: Effects of rustfmt on .expected.
geoffw0 Jun 23, 2025
4530e85
Rust: Repair the test annotations.
geoffw0 Jun 23, 2025
530ded1
Merge branch 'main' into typeinfer
geoffw0 Jun 23, 2025
21bea7e
Merge branch 'main' into typeinfer
geoffw0 Jun 24, 2025
96dcdf9
Rust: Change note.
geoffw0 Jun 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions rust/ql/lib/codeql/rust/internal/TypeInference.qll
Original file line number Diff line number Diff line change
Expand Up @@ -1124,6 +1124,37 @@ private Type inferIndexExprType(IndexExpr ie, TypePath path) {
)
}

pragma[nomagic]
private Type inferArrayExprType(ArrayExpr ae, TypePath path) {
// an array list expression (`[1, 2, 3]`) has the type of the first (any) element
exists(Type type0, TypePath path0 |
type0 = inferType(ae.(ArrayListExpr).getExpr(0), path0) and
result = type0 and
path = TypePath::cons(any(ArrayTypeParameter tp), path0)
)
or
// an array repeat expression (`[1; 3]`) has the type of the repeat operand
exists(Type type0, TypePath path0 |
type0 = inferType(ae.(ArrayRepeatExpr).getRepeatOperand(), path0) and
result = type0 and
path = TypePath::cons(any(ArrayTypeParameter tp), path0)
)
}

pragma[nomagic]
private Type inferForLoopExprType(AstNode n, TypePath path) {
// type of iterable -> type of pattern (loop variable)
exists(ForExpr fe, Type iterableType, TypePath iterablePath |
n = fe.getPat() and
iterableType = inferType(fe.getIterable(), iterablePath) and
(
iterablePath.isCons(any(ArrayTypeParameter tp), path) and
result = iterableType
// TODO: iterables (containers, ranges etc)
)
)
}

final class MethodCall extends Call {
MethodCall() {
exists(this.getReceiver()) and
Expand Down Expand Up @@ -1438,6 +1469,10 @@ private module Cached {
result = inferAwaitExprType(n, path)
or
result = inferIndexExprType(n, path)
or
result = inferArrayExprType(n, path)
or
result = inferForLoopExprType(n, path)
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
multiplePathResolutions
| main.rs:1851:25:1851:36 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:25:1851:36 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:25:1851:36 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:25:1851:36 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:25:1851:36 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:25:1851:36 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:46:1851:57 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:46:1851:57 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:46:1851:57 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:46:1851:57 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:46:1851:57 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:46:1851:57 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:67:1851:78 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:67:1851:78 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:67:1851:78 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:67:1851:78 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:67:1851:78 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:67:1851:78 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:26:1854:37 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:26:1854:37 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:26:1854:37 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:26:1854:37 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:26:1854:37 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:26:1854:37 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:47:1854:58 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:47:1854:58 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:47:1854:58 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:47:1854:58 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:47:1854:58 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:47:1854:58 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:68:1854:79 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:68:1854:79 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:68:1854:79 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:68:1854:79 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:68:1854:79 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:68:1854:79 | ...::from | file://:0:0:0:0 | fn from |
89 changes: 89 additions & 0 deletions rust/ql/test/library-tests/type-inference/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1810,6 +1810,94 @@ mod indexers {
}
}

mod loops {
struct MyCallable {
}

impl MyCallable {
fn new() -> Self {
MyCallable {}
}

fn call(&self) -> i64 {
1
}
}

pub fn f() {
// for loops with arrays

for i in [1, 2, 3] { } // $ type=i:i32
for i in [1, 2, 3].map(|x| x + 1) { } // $ MISSING: type=i:i32
for i in [1, 2, 3].into_iter() { } // $ MISSING: type=i:i32

let vals1 = [1u8, 2, 3]; // $ MISSING: type=vals1:[u8; 3]
for u in vals1 { } // $ type=u:u8

let vals2 = [1u16; 3]; // $ MISSING: type=vals2:[u16; 3]
for u in vals2 { } // $ type=u:u16

let vals3: [u32; 3] = [1, 2, 3]; // $ MISSING: type=vals3:[u32; 3]
for u in vals3 { } // $ type=u:u32

let vals4: [u64; 3] = [1; 3]; // $ MISSING: type=vals4:[u64; 3]
for u in vals4 { } // $ type=u:u64

let mut strings1 = ["foo", "bar", "baz"]; // $ MISSING: type=strings1:[&str; 3]
for s in &strings1 { } // $ MISSING: type=s:&str
for s in &mut strings1 { } // $ MISSING: type=s:&str
for s in strings1 { } // $ type=s:str

let strings2 = [String::from("foo"), String::from("bar"), String::from("baz")]; // $ MISSING: type=strings2:[String; 3]
for s in strings2 { } // $ type=s:String

let strings3 = &[String::from("foo"), String::from("bar"), String::from("baz")]; // $ MISSING: type=strings3:&[String; 3]
for s in strings3 { } // $ MISSING: type=s:String

let callables = [MyCallable::new(), MyCallable::new(), MyCallable::new()]; // $ MISSING: type=callables:[MyCallable; 3]
for c in callables { // $ type=c:MyCallable
let result = c.call(); // $ type=result:i64 method=call
}

// for loops with ranges

for i in 0..10 { } // $ MISSING: type=i:i32
for u in [0u8 .. 10] { } // $ MISSING: type=u:u8

let range1 = std::ops::Range { start: 0u16, end: 10u16 }; // $ MISSING: type=range:std::ops::Range<u16>
for u in range1 { } // $ MISSING: type=i:u16

// for loops with containers

let vals3 = vec![1, 2, 3]; // MISSING: type=vals:Vec<i32>
for i in vals3 { } // $ MISSING: type=i:i32

let vals4 : Vec<&u64> = [1u64, 2, 3].iter().collect();
for u in vals4 { } // $ MISSING: type=u:&u64

let matrix1 = vec![vec![1, 2], vec![3, 4]]; // $ MISSING: type=vals5:Vec<Vec<i32>>
for row in matrix1 { // $ MISSING: type=row:Vec<i32>
for cell in row { // $ MISSING: type=cell:i32
}
}

let mut map1 = std::collections::HashMap::new(); // $ MISSING: type=map1:std::collections::HashMap<_, _>
map1.insert(1, Box::new("one")); // $ method=insert
map1.insert(2, Box::new("two")); // $ method=insert
for key in map1.keys() { } // $ method=keys MISSING: type=key:i32
for value in map1.values() { } // $ method=values MISSING: type=value:Box<&str>
for (key, value) in map1.iter() { } // $ method=iter MISSING: type=key:i32 type=value:Box<&str>
for (key, value) in &map1 { } // $ MISSING: type=key:i32 type=value:Box<&str>

// while loops

let mut a: i64 = 0; // $ type=a:i64
while a < 10 { // $ method=lt MISSING: type=a:i64m
a += 1; // $ type=a:i64 method=add_assign
}
}
}

fn main() {
field_access::f();
method_impl::f();
Expand All @@ -1832,4 +1920,5 @@ fn main() {
async_::f();
impl_trait::f();
indexers::f();
loops::f();
}
Loading