Skip to content

Commit 8bbcded

Browse files
committed
Auto merge of #123179 - scottmcm:inlining-baseline-costs, r=<try>
Rework MIR inlining costs A bunch of the current costs are surprising, probably accidentally from from not writing out the matches in full. For example, a runtime-length `memcpy` was treated as the same cost as an `Unreachable`. This reworks things around two main ideas: - Give everything a baseline cost, because even "free" things do take effort in the compiler (CPU & RAM) to MIR inline, and they're easy to calculate - Then just penalize those things that are materially more than the baseline, like how `[foo; 123]` is far more work than `BinOp::AddUnchecked` in an `Rvalue` By including costs for locals and vardebuginfo this makes some things overall more expensive, but because it also greatly reduces the cost for simple things like local variable addition, other things also become less expensive overall. r? ghost
2 parents ba52720 + 47a5a7f commit 8bbcded

File tree

2 files changed

+139
-19
lines changed

2 files changed

+139
-19
lines changed

compiler/rustc_mir_transform/src/cost_checker.rs

+128-19
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,30 @@ use rustc_middle::mir::visit::*;
22
use rustc_middle::mir::*;
33
use rustc_middle::ty::{self, ParamEnv, Ty, TyCtxt};
44

5-
const INSTR_COST: usize = 5;
6-
const CALL_PENALTY: usize = 25;
7-
const LANDINGPAD_PENALTY: usize = 50;
8-
const RESUME_PENALTY: usize = 45;
5+
// Even if they're zero-cost at runtime, everything has *some* cost to inline
6+
// in terms of copying them into the MIR caller, processing them in codegen, etc.
7+
// These baseline costs give a simple usually-too-low estimate of the cost,
8+
// which will be updated afterwards to account for the "real" costs.
9+
const STMT_BASELINE_COST: usize = 1;
10+
const BLOCK_BASELINE_COST: usize = 3;
11+
const DEBUG_BASELINE_COST: usize = 1;
12+
const LOCAL_BASELINE_COST: usize = 1;
13+
14+
// These penalties represent the cost above baseline for those things which
15+
// have substantially more cost than is typical for their kind.
16+
const CALL_PENALTY: usize = 22;
17+
const LANDINGPAD_PENALTY: usize = 47;
18+
const RESUME_PENALTY: usize = 42;
19+
const DEREF_PENALTY: usize = 4;
20+
const CHECKED_OP_PENALTY: usize = 2;
21+
const THREAD_LOCAL_PENALTY: usize = 20;
22+
const SMALL_SWITCH_PENALTY: usize = 3;
23+
const LARGE_SWITCH_PENALTY: usize = 20;
24+
25+
// Passing arguments isn't free, so give a bonus to functions with lots of them:
26+
// if the body is small despite lots of arguments, some are probably unused.
27+
const EXTRA_ARG_BONUS: usize = 4;
28+
const MAX_ARG_BONUS: usize = CALL_PENALTY;
929

1030
/// Verify that the callee body is compatible with the caller.
1131
#[derive(Clone)]
@@ -27,6 +47,20 @@ impl<'b, 'tcx> CostChecker<'b, 'tcx> {
2747
CostChecker { tcx, param_env, callee_body, instance, cost: 0 }
2848
}
2949

50+
// `Inline` doesn't call `visit_body`, so this is separate from the visitor.
51+
pub fn before_body(&mut self, body: &Body<'tcx>) {
52+
self.cost += BLOCK_BASELINE_COST * body.basic_blocks.len();
53+
self.cost += DEBUG_BASELINE_COST * body.var_debug_info.len();
54+
self.cost += LOCAL_BASELINE_COST * body.local_decls.len();
55+
56+
let total_statements = body.basic_blocks.iter().map(|x| x.statements.len()).sum::<usize>();
57+
self.cost += STMT_BASELINE_COST * total_statements;
58+
59+
if let Some(extra_args) = body.arg_count.checked_sub(2) {
60+
self.cost = self.cost.saturating_sub((EXTRA_ARG_BONUS * extra_args).min(MAX_ARG_BONUS));
61+
}
62+
}
63+
3064
pub fn cost(&self) -> usize {
3165
self.cost
3266
}
@@ -41,14 +75,70 @@ impl<'b, 'tcx> CostChecker<'b, 'tcx> {
4175
}
4276

4377
impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
44-
fn visit_statement(&mut self, statement: &Statement<'tcx>, _: Location) {
45-
// Don't count StorageLive/StorageDead in the inlining cost.
46-
match statement.kind {
47-
StatementKind::StorageLive(_)
78+
fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
79+
match &statement.kind {
80+
StatementKind::Assign(place_and_rvalue) => {
81+
if place_and_rvalue.0.is_indirect_first_projection() {
82+
self.cost += DEREF_PENALTY;
83+
}
84+
self.visit_rvalue(&place_and_rvalue.1, location);
85+
}
86+
StatementKind::Intrinsic(intr) => match &**intr {
87+
NonDivergingIntrinsic::Assume(..) => {}
88+
NonDivergingIntrinsic::CopyNonOverlapping(_cno) => {
89+
self.cost += CALL_PENALTY;
90+
}
91+
},
92+
StatementKind::FakeRead(..)
93+
| StatementKind::SetDiscriminant { .. }
94+
| StatementKind::StorageLive(_)
4895
| StatementKind::StorageDead(_)
96+
| StatementKind::Retag(..)
97+
| StatementKind::PlaceMention(..)
98+
| StatementKind::AscribeUserType(..)
99+
| StatementKind::Coverage(..)
49100
| StatementKind::Deinit(_)
50-
| StatementKind::Nop => {}
51-
_ => self.cost += INSTR_COST,
101+
| StatementKind::ConstEvalCounter
102+
| StatementKind::Nop => {
103+
// No extra cost for these
104+
}
105+
}
106+
}
107+
108+
fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, _location: Location) {
109+
match rvalue {
110+
Rvalue::Use(operand) => {
111+
if let Some(place) = operand.place()
112+
&& place.is_indirect_first_projection()
113+
{
114+
self.cost += DEREF_PENALTY;
115+
}
116+
}
117+
Rvalue::Repeat(_item, count) => {
118+
let count = count.try_to_target_usize(self.tcx).unwrap_or(u64::MAX);
119+
self.cost += (STMT_BASELINE_COST * count as usize).min(CALL_PENALTY);
120+
}
121+
Rvalue::Aggregate(_kind, fields) => {
122+
self.cost += STMT_BASELINE_COST * fields.len();
123+
}
124+
Rvalue::CheckedBinaryOp(..) => {
125+
self.cost += CHECKED_OP_PENALTY;
126+
}
127+
Rvalue::ThreadLocalRef(..) => {
128+
self.cost += THREAD_LOCAL_PENALTY;
129+
}
130+
Rvalue::Ref(..)
131+
| Rvalue::AddressOf(..)
132+
| Rvalue::Len(..)
133+
| Rvalue::Cast(..)
134+
| Rvalue::BinaryOp(..)
135+
| Rvalue::NullaryOp(..)
136+
| Rvalue::UnaryOp(..)
137+
| Rvalue::Discriminant(..)
138+
| Rvalue::ShallowInitBox(..)
139+
| Rvalue::CopyForDeref(..) => {
140+
// No extra cost for these
141+
}
52142
}
53143
}
54144

@@ -63,24 +153,35 @@ impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
63153
if let UnwindAction::Cleanup(_) = unwind {
64154
self.cost += LANDINGPAD_PENALTY;
65155
}
66-
} else {
67-
self.cost += INSTR_COST;
68156
}
69157
}
70-
TerminatorKind::Call { func: Operand::Constant(ref f), unwind, .. } => {
71-
let fn_ty = self.instantiate_ty(f.const_.ty());
72-
self.cost += if let ty::FnDef(def_id, _) = *fn_ty.kind()
158+
TerminatorKind::Call { ref func, unwind, .. } => {
159+
if let Some(f) = func.constant()
160+
&& let fn_ty = self.instantiate_ty(f.ty())
161+
&& let ty::FnDef(def_id, _) = *fn_ty.kind()
73162
&& tcx.intrinsic(def_id).is_some()
74163
{
75164
// Don't give intrinsics the extra penalty for calls
76-
INSTR_COST
77165
} else {
78-
CALL_PENALTY
166+
self.cost += CALL_PENALTY;
79167
};
80168
if let UnwindAction::Cleanup(_) = unwind {
81169
self.cost += LANDINGPAD_PENALTY;
82170
}
83171
}
172+
TerminatorKind::SwitchInt { ref discr, ref targets } => {
173+
if let Operand::Constant(..) = discr {
174+
// This'll be a goto once we're monomorphizing
175+
} else {
176+
// 0/1/unreachable is extremely common (bool, Option, Result, ...)
177+
// but once there's more this can be a fair bit of work.
178+
self.cost += if targets.all_targets().len() <= 3 {
179+
SMALL_SWITCH_PENALTY
180+
} else {
181+
LARGE_SWITCH_PENALTY
182+
};
183+
}
184+
}
84185
TerminatorKind::Assert { unwind, .. } => {
85186
self.cost += CALL_PENALTY;
86187
if let UnwindAction::Cleanup(_) = unwind {
@@ -89,12 +190,20 @@ impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
89190
}
90191
TerminatorKind::UnwindResume => self.cost += RESUME_PENALTY,
91192
TerminatorKind::InlineAsm { unwind, .. } => {
92-
self.cost += INSTR_COST;
93193
if let UnwindAction::Cleanup(_) = unwind {
94194
self.cost += LANDINGPAD_PENALTY;
95195
}
96196
}
97-
_ => self.cost += INSTR_COST,
197+
TerminatorKind::Goto { .. }
198+
| TerminatorKind::UnwindTerminate(..)
199+
| TerminatorKind::Return
200+
| TerminatorKind::Yield { .. }
201+
| TerminatorKind::CoroutineDrop
202+
| TerminatorKind::FalseEdge { .. }
203+
| TerminatorKind::FalseUnwind { .. }
204+
| TerminatorKind::Unreachable => {
205+
// No extra cost for these
206+
}
98207
}
99208
}
100209
}

compiler/rustc_mir_transform/src/inline.rs

+11
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,17 @@ impl<'tcx> Inliner<'tcx> {
506506
let mut checker =
507507
CostChecker::new(self.tcx, self.param_env, Some(callsite.callee), callee_body);
508508

509+
checker.before_body(callee_body);
510+
511+
let baseline_cost = checker.cost();
512+
if baseline_cost > threshold {
513+
debug!(
514+
"NOT inlining {:?} [baseline_cost={} > threshold={}]",
515+
callsite, baseline_cost, threshold
516+
);
517+
return Err("baseline_cost above threshold");
518+
}
519+
509520
// Traverse the MIR manually so we can account for the effects of inlining on the CFG.
510521
let mut work_list = vec![START_BLOCK];
511522
let mut visited = BitSet::new_empty(callee_body.basic_blocks.len());

0 commit comments

Comments
 (0)