forked from rust-lang/rust
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathssa_range_prop.rs
More file actions
219 lines (202 loc) · 8.64 KB
/
ssa_range_prop.rs
File metadata and controls
219 lines (202 loc) · 8.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
//! A pass that propagates the known ranges of SSA locals.
//!
//! The range of an SSA local can be known at certain locations, as in the following code:
//! ```
//! fn foo(a: u32) {
//!     let b = a < 9; // the integer representation of b is within the full range [0, 2).
//!     if b {
//!         let c = b; // c is true since b is within the range [1, 2).
//!         let d = a < 10; // d is true since a is within the range [0, 9).
//!     }
//! }
//! ```
use rustc_abi::WrappingRange;
use rustc_const_eval::interpret::Scalar;
use rustc_data_structures::fx::FxHashMap;
use rustc_data_structures::graph::dominators::Dominators;
use rustc_index::bit_set::DenseBitSet;
use rustc_middle::mir::visit::MutVisitor;
use rustc_middle::mir::{BasicBlock, Body, Location, Operand, Place, TerminatorKind, *};
use rustc_middle::ty::{TyCtxt, TypingEnv};
use rustc_span::DUMMY_SP;
use crate::ssa::SsaLocals;
/// MIR pass that rewrites operands reading an SSA place into constants at
/// locations where a single known value has been recorded for that place.
pub(super) struct SsaRangePropagation;
impl<'tcx> crate::MirPass<'tcx> for SsaRangePropagation {
    /// The pass only runs when the session's MIR optimization level is above 1.
    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
        sess.mir_opt_level() >= 2
    }

    /// Walks every reachable block of `body` in reverse postorder with a `RangeSet` visitor,
    /// which records known ranges and rewrites operands in place.
    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
        let typing_env = body.typing_env(tcx);
        let ssa = SsaLocals::new(tcx, body, typing_env);

        // Take owned copies of the traversal order and of the dominators: the borrows handed
        // out by `basic_blocks` cannot be kept alive while the blocks are being mutated.
        let visit_order = body.basic_blocks.reverse_postorder().to_vec();
        let dominators = body.basic_blocks.dominators().clone();

        let mut range_set =
            RangeSet::new(tcx, typing_env, body, &ssa, &body.local_decls, dominators);

        // Only statements and operands are rewritten, so the CFG is preserved.
        let blocks = body.basic_blocks.as_mut_preserves_cfg();
        for bb in visit_order {
            range_set.visit_basic_block_data(bb, &mut blocks[bb]);
        }
    }

    /// This pass is an optimization only.
    fn is_required(&self) -> bool {
        false
    }
}
/// Visitor state of the range propagation: the analysis inputs it consults plus the ranges
/// recorded so far while walking the body.
struct RangeSet<'tcx, 'body, 'a> {
    tcx: TyCtxt<'tcx>,
    typing_env: TypingEnv<'tcx>,
    /// SSA analysis used to decide which places may be tracked.
    ssa: &'a SsaLocals,
    /// Local declarations of the visited body, used to compute the type of a place.
    local_decls: &'body LocalDecls<'tcx>,
    /// Dominators of the visited body, used to test whether a recorded range applies at a
    /// given location.
    dominators: Dominators<BasicBlock>,
    /// Basic blocks that `RangeSet::new` recorded as having a single predecessor edge.
    unique_predecessors: DenseBitSet<BasicBlock>,
    /// Known ranges per place, each paired with the location from which it was recorded.
    ranges: FxHashMap<Place<'tcx>, Vec<(Location, WrappingRange)>>,
}
impl<'tcx, 'body, 'a> RangeSet<'tcx, 'body, 'a> {
fn new(
tcx: TyCtxt<'tcx>,
typing_env: TypingEnv<'tcx>,
body: &Body<'tcx>,
ssa: &'a SsaLocals,
local_decls: &'body LocalDecls<'tcx>,
dominators: Dominators<BasicBlock>,
) -> Self {
let predecessors = body.basic_blocks.predecessors();
let mut unique_predecessors = DenseBitSet::new_empty(body.basic_blocks.len());
for bb in body.basic_blocks.indices() {
if predecessors[bb].len() == 1 {
unique_predecessors.insert(bb);
}
}
RangeSet {
tcx,
typing_env,
ssa,
local_decls,
dominators,
ranges: FxHashMap::default(),
unique_predecessors,
}
}
/// Create a new known range at the location.
fn insert_range(&mut self, place: Place<'tcx>, location: Location, range: WrappingRange) {
assert!(self.is_ssa(place));
self.ranges.entry(place).or_default().push((location, range));
}
/// Get the known range at the location.
fn get_range(&self, place: &Place<'tcx>, location: Location) -> Option<WrappingRange> {
let Some(ranges) = self.ranges.get(place) else {
return None;
};
// FIXME: This should use the intersection of all valid ranges.
let (_, range) =
ranges.iter().find(|(range_loc, _)| range_loc.dominates(location, &self.dominators))?;
Some(*range)
}
fn try_as_constant(
&mut self,
place: Place<'tcx>,
location: Location,
) -> Option<ConstOperand<'tcx>> {
if let Some(range) = self.get_range(&place, location)
&& range.start == range.end
{
let ty = place.ty(self.local_decls, self.tcx).ty;
let layout = self.tcx.layout_of(self.typing_env.as_query_input(ty)).ok()?;
let value = ConstValue::Scalar(Scalar::from_uint(range.start, layout.size));
let const_ = Const::Val(value, ty);
return Some(ConstOperand { span: DUMMY_SP, user_ty: None, const_ });
}
None
}
fn is_ssa(&self, place: Place<'tcx>) -> bool {
self.ssa.is_ssa(place.local) && place.is_stable_offset()
}
}
impl<'tcx> MutVisitor<'tcx> for RangeSet<'tcx, '_, '_> {
fn tcx(&self) -> TyCtxt<'tcx> {
self.tcx
}
fn visit_operand(&mut self, operand: &mut Operand<'tcx>, location: Location) {
// Attempts to simplify an operand to a constant value.
if let Some(place) = operand.place()
&& let Some(const_) = self.try_as_constant(place, location)
{
*operand = Operand::Constant(Box::new(const_));
};
}
fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) {
self.super_statement(statement, location);
match &statement.kind {
StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(operand)) => {
if let Some(place) = operand.place()
&& self.is_ssa(place)
{
let successor = location.successor_within_block();
let range = WrappingRange { start: 1, end: 1 };
self.insert_range(place, successor, range);
}
}
_ => {}
}
}
fn visit_terminator(&mut self, terminator: &mut Terminator<'tcx>, location: Location) {
self.super_terminator(terminator, location);
match &terminator.kind {
TerminatorKind::Assert { cond, expected, target, .. } => {
if let Some(place) = cond.place()
&& self.is_ssa(place)
{
let successor = Location { block: *target, statement_index: 0 };
if location.dominates(successor, &self.dominators) {
assert_ne!(location.block, successor.block);
let val = *expected as u128;
let range = WrappingRange { start: val, end: val };
self.insert_range(place, successor, range);
}
}
}
TerminatorKind::SwitchInt { discr, targets } => {
if let Some(place) = discr.place()
&& self.is_ssa(place)
// Reduce the potential compile-time overhead.
&& targets.all_targets().len() < 16
{
let mut distinct_targets: FxHashMap<BasicBlock, u64> = FxHashMap::default();
for (_, target) in targets.iter() {
let targets = distinct_targets.entry(target).or_default();
*targets += 1;
}
for (val, target) in targets.iter() {
if distinct_targets[&target] != 1 {
// FIXME: For multiple targets, the range can be the union of their values.
continue;
}
let successor = Location { block: target, statement_index: 0 };
if self.unique_predecessors.contains(successor.block) {
assert_ne!(location.block, successor.block);
let range = WrappingRange { start: val, end: val };
self.insert_range(place, successor, range);
}
}
// FIXME: The range for the otherwise target be extend to more types.
// For instance, `val` is within the range [4, 1) at the otherwise target of `matches!(val, 1 | 2 | 3)`.
let otherwise = Location { block: targets.otherwise(), statement_index: 0 };
if place.ty(self.local_decls, self.tcx).ty.is_bool()
&& let [val] = targets.all_values()
&& self.unique_predecessors.contains(otherwise.block)
{
assert_ne!(location.block, otherwise.block);
let range = if val.get() == 0 {
WrappingRange { start: 1, end: 1 }
} else {
WrappingRange { start: 0, end: 0 }
};
self.insert_range(place, otherwise, range);
}
}
}
_ => {}
}
}
}