Skip to content

Rust: Type inference for for loops and array expressions #19754

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
Draft
33 changes: 33 additions & 0 deletions rust/ql/lib/codeql/rust/internal/TypeInference.qll
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,16 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
prefix2.isEmpty()
)
)
or
// an array list expression (`[1, 2, 3]`) has the type of the first (any) element
n1.(ArrayListExpr).getExpr(_) = n2 and
prefix1 = TypePath::singleton(TArrayTypeParameter()) and
prefix2.isEmpty()
or
// an array repeat expression (`[1; 3]`) has the type of the repeat operand
n1.(ArrayRepeatExpr).getRepeatOperand() = n2 and
prefix1 = TypePath::singleton(TArrayTypeParameter()) and
prefix2.isEmpty()
}

pragma[nomagic]
Expand Down Expand Up @@ -1124,6 +1134,27 @@ private Type inferIndexExprType(IndexExpr ie, TypePath path) {
)
}

pragma[nomagic]
private Type inferForLoopExprType(AstNode n, TypePath path) {
// type of iterable -> type of pattern (loop variable)
exists(ForExpr fe, Type iterableType, TypePath iterablePath |
n = fe.getPat() and
iterableType = inferType(fe.getIterable(), iterablePath) and
result = iterableType and
(
iterablePath.isCons(any(Vec v).getElementTypeParameter(), path)
or
iterablePath.isCons(any(ArrayTypeParameter tp), path)
or
exists(TypePath path0 |
iterablePath.isCons(any(RefTypeParameter tp), path0) and
path0.isCons(any(SliceTypeParameter tp), path)
)
// TODO: iterables (general case for containers, ranges etc)
)
)
}

final class MethodCall extends Call {
MethodCall() {
exists(this.getReceiver()) and
Expand Down Expand Up @@ -1438,6 +1469,8 @@ private module Cached {
result = inferAwaitExprType(n, path)
or
result = inferIndexExprType(n, path)
or
result = inferForLoopExprType(n, path)
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
multiplePathResolutions
| main.rs:1851:25:1851:36 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:25:1851:36 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:25:1851:36 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:25:1851:36 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:25:1851:36 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:25:1851:36 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:46:1851:57 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:46:1851:57 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:46:1851:57 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:46:1851:57 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:46:1851:57 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:46:1851:57 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:67:1851:78 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:67:1851:78 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:67:1851:78 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:67:1851:78 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:67:1851:78 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1851:67:1851:78 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:26:1854:37 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:26:1854:37 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:26:1854:37 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:26:1854:37 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:26:1854:37 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:26:1854:37 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:47:1854:58 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:47:1854:58 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:47:1854:58 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:47:1854:58 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:47:1854:58 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:47:1854:58 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:68:1854:79 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:68:1854:79 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:68:1854:79 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:68:1854:79 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:68:1854:79 | ...::from | file://:0:0:0:0 | fn from |
| main.rs:1854:68:1854:79 | ...::from | file://:0:0:0:0 | fn from |
99 changes: 99 additions & 0 deletions rust/ql/test/library-tests/type-inference/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1810,6 +1810,104 @@ mod indexers {
}
}

mod loops {
struct MyCallable {
}

impl MyCallable {
fn new() -> Self {
MyCallable {}
}

fn call(&self) -> i64 {
1
}
}

pub fn f() {
// for loops with arrays

for i in [1, 2, 3] { } // $ type=i:i32
for i in [1, 2, 3].map(|x| x + 1) { } // $ MISSING: type=i:i32
for i in [1, 2, 3].into_iter() { } // $ MISSING: type=i:i32

let vals1 = [1u8, 2, 3]; // $ MISSING: type=vals1:[u8; 3]
for u in vals1 { } // $ type=u:u8

let vals2 = [1u16; 3]; // $ MISSING: type=vals2:[u16; 3]
for u in vals2 { } // $ type=u:u16

let vals3: [u32; 3] = [1, 2, 3]; // $ MISSING: type=vals3:[u32; 3]
for u in vals3 { } // $ type=u:u32

let vals4: [u64; 3] = [1; 3]; // $ MISSING: type=vals4:[u64; 3]
for u in vals4 { } // $ type=u:u64

let mut strings1 = ["foo", "bar", "baz"]; // $ MISSING: type=strings1:[&str; 3]
for s in &strings1 { } // $ MISSING: type=s:&str
for s in &mut strings1 { } // $ MISSING: type=s:&str
for s in strings1 { } // $ type=s:str

let strings2 = [String::from("foo"), String::from("bar"), String::from("baz")]; // $ MISSING: type=strings2:[String; 3]
for s in strings2 { } // $ type=s:String

let strings3 = &[String::from("foo"), String::from("bar"), String::from("baz")]; // $ MISSING: type=strings3:&[String; 3]
for s in strings3 { } // $ MISSING: type=s:String

let callables = [MyCallable::new(), MyCallable::new(), MyCallable::new()]; // $ MISSING: type=callables:[MyCallable; 3]
for c in callables { // $ type=c:MyCallable
let result = c.call(); // $ type=result:i64 method=call
}

// for loops with ranges

for i in 0..10 { } // $ MISSING: type=i:i32
for u in [0u8 .. 10] { } // $ MISSING: type=u:u8

let range1 = std::ops::Range { start: 0u16, end: 10u16 }; // $ MISSING: type=range:std::ops::Range<u16>
for u in range1 { } // $ MISSING: type=i:u16

// for loops with containers

let vals3 = vec![1, 2, 3]; // $ MISSING: type=vals3:Vec<i32>
for i in vals3 { } // $ MISSING: type=i:i32

let vals4 = [1u16, 2, 3].to_vec(); // $ MISSING: type=vals4:Vec<u16>
for u in vals4 { } // $ MISSING: type=u:u16

let vals5 = Vec::from([1u32, 2, 3]); // $ MISSING: type=vals5:Vec<u32>
for u in vals5 { } // $ MISSING: type=u:u32

let vals6 : Vec<&u64> = [1u64, 2, 3].iter().collect(); // $ MISSING: type=vals6:Vec<&u64>
for u in vals6 { } // $ MISSING: type=u:&u64

let mut vals7 = Vec::new(); // $ MISSING: type=vals7:Vec<u8>
vals7.push(1u8); // $ method=push
for u in vals7 { } // $ MISSING: type=u:u8

let matrix1 = vec![vec![1, 2], vec![3, 4]]; // $ MISSING: type=vals5:Vec<Vec<i32>>
for row in matrix1 { // $ MISSING: type=row:Vec<i32>
for cell in row { // $ MISSING: type=cell:i32
}
}

let mut map1 = std::collections::HashMap::new(); // $ MISSING: type=map1:std::collections::HashMap<_, _>
map1.insert(1, Box::new("one")); // $ method=insert
map1.insert(2, Box::new("two")); // $ method=insert
for key in map1.keys() { } // $ method=keys MISSING: type=key:i32
for value in map1.values() { } // $ method=values MISSING: type=value:Box<&str>
for (key, value) in map1.iter() { } // $ method=iter MISSING: type=key:i32 type=value:Box<&str>
for (key, value) in &map1 { } // $ MISSING: type=key:i32 type=value:Box<&str>

// while loops

let mut a: i64 = 0; // $ type=a:i64
while a < 10 { // $ method=lt MISSING: type=a:i64m
a += 1; // $ type=a:i64 method=add_assign
}
}
}

fn main() {
field_access::f();
method_impl::f();
Expand All @@ -1832,4 +1930,5 @@ fn main() {
async_::f();
impl_trait::f();
indexers::f();
loops::f();
}
Loading