Skip to content

Commit 2ea6288

Browse files
authored
[wasm] Jiterpreter monitoring phase take 2 (#83489)
* Add a monitoring phase to jiterpreter traces, that monitors an approximate number of opcodes executed before specific types of bailouts. If a trace bails out frequently without executing enough opcodes, it will be rejected and turned into a nop to improve performance. * Fix assert when running out of TraceInfo space
1 parent 6047cdc commit 2ea6288

File tree

9 files changed

+220
-35
lines changed

9 files changed

+220
-35
lines changed

src/mono/mono/mini/interp/interp.c

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3798,6 +3798,11 @@ max_d (double lhs, double rhs)
37983798
return fmax (lhs, rhs);
37993799
}
38003800

3801+
#if HOST_BROWSER
3802+
// Dummy call info used outside of monitoring phase. We don't care what's in it
3803+
static JiterpreterCallInfo jiterpreter_call_info = { 0 };
3804+
#endif
3805+
38013806
/*
38023807
* If CLAUSE_ARGS is non-null, start executing from it.
38033808
* The ERROR argument is used to avoid declaring an error object for every interp frame, its not used
@@ -7782,15 +7787,11 @@ MINT_IN_CASE(MINT_BRTRUE_I8_SP) ZEROP_SP(gint64, !=); MINT_IN_BREAK;
77827787
* (note that right now threading doesn't work, but it's worth being correct
77837788
* here so that implementing thread support will be easier later.)
77847789
*/
7785-
*mutable_ip = MINT_TIER_NOP_JITERPRETER;
7786-
mono_memory_barrier ();
7787-
*(volatile JiterpreterThunk*)(ip + 1) = prepare_result;
7788-
mono_memory_barrier ();
7789-
*mutable_ip = MINT_TIER_ENTER_JITERPRETER;
7790+
*mutable_ip = MINT_TIER_MONITOR_JITERPRETER;
77907791
// now execute the trace
77917792
// this isn't important for performance, but it makes it easier to use the
77927793
// jiterpreter early in automated tests where code only runs once
7793-
offset = prepare_result(frame, locals);
7794+
offset = prepare_result(frame, locals, &jiterpreter_call_info);
77947795
ip = (guint16*) (((guint8*)ip) + offset);
77957796
break;
77967797
}
@@ -7801,9 +7802,18 @@ MINT_IN_CASE(MINT_BRTRUE_I8_SP) ZEROP_SP(gint64, !=); MINT_IN_BREAK;
78017802
MINT_IN_BREAK;
78027803
}
78037804

7805+
MINT_IN_CASE(MINT_TIER_MONITOR_JITERPRETER) {
7806+
// The trace is in monitoring mode, where we track how far it actually goes
7807+
// each time it is executed for a while. After N more hits, we either
7808+
// turn it into an ENTER or a NOP depending on how well it is working
7809+
ptrdiff_t offset = mono_jiterp_monitor_trace (ip, frame, locals);
7810+
ip = (guint16*) (((guint8*)ip) + offset);
7811+
MINT_IN_BREAK;
7812+
}
7813+
78047814
MINT_IN_CASE(MINT_TIER_ENTER_JITERPRETER) {
78057815
JiterpreterThunk thunk = (void*)READ32(ip + 1);
7806-
ptrdiff_t offset = thunk(frame, locals);
7816+
ptrdiff_t offset = thunk(frame, locals, &jiterpreter_call_info);
78077817
ip = (guint16*) (((guint8*)ip) + offset);
78087818
MINT_IN_BREAK;
78097819
}

src/mono/mono/mini/interp/jiterpreter.c

Lines changed: 107 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -877,13 +877,17 @@ typedef struct {
877877
// 64-bits because it can get very high if estimate heat is turned on
878878
gint64 hit_count;
879879
JiterpreterThunk thunk;
880+
int penalty_total;
880881
} TraceInfo;
881882

882-
#define MAX_TRACE_SEGMENTS 256
883+
// The maximum number of trace segments used to store TraceInfo. This limits
884+
// the maximum total number of traces to MAX_TRACE_SEGMENTS * TRACE_SEGMENT_SIZE
885+
#define MAX_TRACE_SEGMENTS 1024
883886
#define TRACE_SEGMENT_SIZE 1024
884887

885888
static volatile gint32 trace_count = 0;
886889
static TraceInfo *trace_segments[MAX_TRACE_SEGMENTS] = { NULL };
890+
static gint32 traces_rejected = 0;
887891

888892
static TraceInfo *
889893
trace_info_allocate_segment (gint32 index) {
@@ -917,7 +921,14 @@ trace_info_get (gint32 index) {
917921

918922
static gint32
919923
trace_info_alloc () {
920-
gint32 index = trace_count++;
924+
gint32 index = trace_count++,
925+
limit = (MAX_TRACE_SEGMENTS * TRACE_SEGMENT_SIZE);
926+
// Make sure we're not out of space in the trace info table.
927+
if (index == limit)
928+
g_print ("MONO_WASM: Reached maximum number of jiterpreter trace entry points (%d).\n", limit);
929+
if (index >= limit)
930+
return -1;
931+
921932
TraceInfo *info = trace_info_get (index);
922933
info->hit_count = 0;
923934
info->thunk = NULL;
@@ -984,20 +995,24 @@ jiterp_insert_entry_points (void *_imethod, void *_td)
984995

985996
if (enabled && should_generate) {
986997
gint32 trace_index = trace_info_alloc ();
987-
988-
td->cbb = bb;
989-
imethod->contains_traces = TRUE;
990-
InterpInst *ins = mono_jiterp_insert_ins (td, NULL, MINT_TIER_PREPARE_JITERPRETER);
991-
memcpy(ins->data, &trace_index, sizeof (trace_index));
992-
993-
// Clear the instruction counter
994-
instruction_count = 0;
995-
996-
// Note that we only clear enter_at_next here, after generating a trace.
997-
// This means that the flag will stay set intentionally if we keep failing
998-
// to generate traces, perhaps due to a string of small basic blocks
999-
// or multiple call instructions.
1000-
enter_at_next = bb->contains_call_instruction;
998+
if (trace_index < 0) {
999+
// We're out of space in the TraceInfo table.
1000+
return;
1001+
} else {
1002+
td->cbb = bb;
1003+
imethod->contains_traces = TRUE;
1004+
InterpInst *ins = mono_jiterp_insert_ins (td, NULL, MINT_TIER_PREPARE_JITERPRETER);
1005+
memcpy(ins->data, &trace_index, sizeof (trace_index));
1006+
1007+
// Clear the instruction counter
1008+
instruction_count = 0;
1009+
1010+
// Note that we only clear enter_at_next here, after generating a trace.
1011+
// This means that the flag will stay set intentionally if we keep failing
1012+
// to generate traces, perhaps due to a string of small basic blocks
1013+
// or multiple call instructions.
1014+
enter_at_next = bb->contains_call_instruction;
1015+
}
10011016
} else if (is_backwards_branch && enabled && !should_generate) {
10021017
// We failed to start a trace at a backwards branch target, but that might just mean
10031018
// that the loop body starts with one or two unsupported opcodes, so it may be
@@ -1233,7 +1248,7 @@ mono_jiterp_stelem_ref (
12331248

12341249
EMSCRIPTEN_KEEPALIVE int
12351250
mono_jiterp_trace_transfer (
1236-
int displacement, JiterpreterThunk trace, void *frame, void *pLocals
1251+
int displacement, JiterpreterThunk trace, void *frame, void *pLocals, JiterpreterCallInfo *cinfo
12371252
) {
12381253
// This indicates that we lost a race condition, so there's no trace to call. Just bail out.
12391254
// FIXME: Detect this at trace generation time and spin until the trace is available
@@ -1245,7 +1260,7 @@ mono_jiterp_trace_transfer (
12451260
// safepoint was already performed by the trace.
12461261
int relative_displacement = 0;
12471262
while (relative_displacement == 0)
1248-
relative_displacement = trace(frame, pLocals);
1263+
relative_displacement = trace(frame, pLocals, cinfo);
12491264

12501265
// We got a relative displacement other than 0, so the trace bailed out somewhere or
12511266
// branched to another branch target. Time to return (and our caller will return too.)
@@ -1326,6 +1341,80 @@ mono_jiterp_write_number_unaligned (void *dest, double value, int mode) {
13261341
}
13271342
}
13281343

1344+
#define TRACE_PENALTY_LIMIT 200
1345+
#define TRACE_MONITORING_DETAILED FALSE
1346+
1347+
ptrdiff_t
1348+
mono_jiterp_monitor_trace (const guint16 *ip, void *_frame, void *locals)
1349+
{
1350+
gint32 index = READ32 (ip + 1);
1351+
TraceInfo *info = trace_info_get (index);
1352+
g_assert (info);
1353+
1354+
JiterpreterThunk thunk = info->thunk;
1355+
// FIXME: This shouldn't be possible
1356+
g_assert (((guint32)(void *)thunk) > JITERPRETER_NOT_JITTED);
1357+
1358+
JiterpreterCallInfo cinfo;
1359+
cinfo.backward_branch_taken = 0;
1360+
cinfo.bailout_opcode_count = -1;
1361+
1362+
InterpFrame *frame = _frame;
1363+
1364+
ptrdiff_t result = thunk (frame, locals, &cinfo);
1365+
// If a backward branch was taken, we can treat the trace as if it successfully
1366+
// executed at least one time. We don't know how long it actually ran, but back
1367+
// branches are almost always going to be loops. It's fine if a bailout happens
1368+
// after multiple loop iterations.
1369+
if (
1370+
(cinfo.bailout_opcode_count >= 0) &&
1371+
!cinfo.backward_branch_taken &&
1372+
(cinfo.bailout_opcode_count < mono_opt_jiterpreter_trace_monitoring_long_distance)
1373+
) {
1374+
// Start with a penalty of 2 and lerp all the way down to 0
1375+
float scaled = (float)(cinfo.bailout_opcode_count - mono_opt_jiterpreter_trace_monitoring_short_distance)
1376+
/ (mono_opt_jiterpreter_trace_monitoring_long_distance - mono_opt_jiterpreter_trace_monitoring_short_distance);
1377+
int penalty = MIN ((int)((1.0f - scaled) * TRACE_PENALTY_LIMIT), TRACE_PENALTY_LIMIT);
1378+
info->penalty_total += penalty;
1379+
1380+
// g_print ("trace #%d @%d '%s' bailout recorded at opcode #%d, penalty=%d\n", index, ip, frame->imethod->method->name, cinfo.bailout_opcode_count, penalty);
1381+
}
1382+
1383+
gint64 hit_count = info->hit_count++ - mono_opt_jiterpreter_minimum_trace_hit_count;
1384+
if (hit_count == mono_opt_jiterpreter_trace_monitoring_period) {
1385+
// Prepare to enable the trace
1386+
volatile guint16 *mutable_ip = (volatile guint16*)ip;
1387+
*mutable_ip = MINT_TIER_NOP_JITERPRETER;
1388+
1389+
mono_memory_barrier ();
1390+
float average_penalty = info->penalty_total / (float)hit_count / 100.0f,
1391+
threshold = (mono_opt_jiterpreter_trace_monitoring_max_average_penalty / 100.0f);
1392+
1393+
if (average_penalty <= threshold) {
1394+
*(volatile JiterpreterThunk*)(ip + 1) = thunk;
1395+
mono_memory_barrier ();
1396+
*mutable_ip = MINT_TIER_ENTER_JITERPRETER;
1397+
if (mono_opt_jiterpreter_stats_enabled && TRACE_MONITORING_DETAILED)
1398+
g_print ("trace #%d @%d '%s' accepted; average_penalty %f <= %f\n", index, ip, frame->imethod->method->name, average_penalty, threshold);
1399+
} else {
1400+
traces_rejected++;
1401+
if (mono_opt_jiterpreter_stats_enabled) {
1402+
char * full_name = mono_method_get_full_name (frame->imethod->method);
1403+
g_print ("trace #%d @%d '%s' rejected; average_penalty %f > %f\n", index, ip, full_name, average_penalty, threshold);
1404+
g_free (full_name);
1405+
}
1406+
}
1407+
}
1408+
1409+
return result;
1410+
}
1411+
1412+
EMSCRIPTEN_KEEPALIVE gint32
1413+
mono_jiterp_get_rejected_trace_count ()
1414+
{
1415+
return traces_rejected;
1416+
}
1417+
13291418
// HACK: fix C4206
13301419
EMSCRIPTEN_KEEPALIVE
13311420
#endif // HOST_BROWSER

src/mono/mono/mini/interp/jiterpreter.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,12 @@
2020
// NOT_JITTED indicates that the trace was not jitted and it should be turned into a NOP
2121
#define JITERPRETER_NOT_JITTED 1
2222

23-
typedef const ptrdiff_t (*JiterpreterThunk) (void *frame, void *pLocals);
23+
typedef struct {
24+
gint32 backward_branch_taken;
25+
gint32 bailout_opcode_count;
26+
} JiterpreterCallInfo;
27+
28+
typedef const ptrdiff_t (*JiterpreterThunk) (void *frame, void *pLocals, JiterpreterCallInfo *cinfo);
2429
typedef void (*WasmJitCallThunk) (void *ret_sp, void *sp, void *ftndesc, gboolean *thrown);
2530
typedef void (*WasmDoJitCall) (gpointer cb, gpointer arg, gboolean *out_thrown);
2631

@@ -139,6 +144,9 @@ mono_jiterp_imethod_to_ftnptr (InterpMethod *imethod);
139144
void
140145
mono_jiterp_enum_hasflag (MonoClass *klass, gint32 *dest, stackval *sp1, stackval *sp2);
141146

147+
ptrdiff_t
148+
mono_jiterp_monitor_trace (const guint16 *ip, void *frame, void *locals);
149+
142150
#endif // __MONO_MINI_INTERPRETER_INTERNALS_H__
143151

144152
extern WasmDoJitCall jiterpreter_do_jit_call;

src/mono/mono/mini/interp/mintops.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -840,6 +840,7 @@ OPDEF(MINT_METADATA_UPDATE_LDFLDA, "metadata_update.ldflda", 5, 1, 1, MintOpTwoS
840840
OPDEF(MINT_TIER_PREPARE_JITERPRETER, "tier_prepare_jiterpreter", 3, 0, 0, MintOpInt)
841841
OPDEF(MINT_TIER_NOP_JITERPRETER, "tier_nop_jiterpreter", 3, 0, 0, MintOpInt)
842842
OPDEF(MINT_TIER_ENTER_JITERPRETER, "tier_enter_jiterpreter", 3, 0, 0, MintOpInt)
843+
OPDEF(MINT_TIER_MONITOR_JITERPRETER, "tier_monitor_jiterpreter", 3, 0, 0, MintOpInt)
843844
#endif // HOST_BROWSER
844845

845846
IROPDEF(MINT_NOP, "nop", 1, 0, 0, MintOpNoArgs)

src/mono/mono/utils/options-def.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,14 @@ DEFINE_INT(jiterpreter_minimum_trace_length, "jiterpreter-minimum-trace-length",
119119
DEFINE_INT(jiterpreter_minimum_distance_between_traces, "jiterpreter-minimum-distance-between-traces", 4, "Don't insert entry points closer together than this")
120120
// once a trace entry point is inserted, we only actually JIT code for it once it's been hit this many times
121121
DEFINE_INT(jiterpreter_minimum_trace_hit_count, "jiterpreter-minimum-trace-hit-count", 5000, "JIT trace entry points once they are hit this many times")
122+
// trace prepares turn into a monitor opcode and stay one this long before being converted to enter or nop
123+
DEFINE_INT(jiterpreter_trace_monitoring_period, "jiterpreter-trace-monitoring-period", 1000, "Monitor jitted traces for this many calls to determine whether to keep them")
124+
// traces that process less than this many opcodes have a high exit penalty, more than this have a low exit penalty
125+
DEFINE_INT(jiterpreter_trace_monitoring_short_distance, "jiterpreter-trace-monitoring-short-distance", 4, "Traces that exit after processing this many opcodes have a reduced exit penalty")
126+
// traces that process this many opcodes have no exit penalty
127+
DEFINE_INT(jiterpreter_trace_monitoring_long_distance, "jiterpreter-trace-monitoring-long-distance", 10, "Traces that exit after processing this many opcodes have no exit penalty")
128+
// the average penalty value for a trace is compared against this threshold / 100 to decide whether to discard it
129+
DEFINE_INT(jiterpreter_trace_monitoring_max_average_penalty, "jiterpreter-trace-monitoring-max-average-penalty", 75, "If the average penalty value for a trace is above this value it will be rejected")
122130
// After a do_jit_call call site is hit this many times, we will queue it to be jitted
123131
DEFINE_INT(jiterpreter_jit_call_trampoline_hit_count, "jiterpreter-jit-call-hit-count", 1000, "Queue specialized do_jit_call trampoline for JIT after this many hits")
124132
// After a do_jit_call call site is hit this many times without being jitted, we will flush the JIT queue

src/mono/wasm/runtime/cwraps.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ const fn_signatures: SigLine[] = [
121121
[true, "mono_jiterp_debug_count", "number", []],
122122
[true, "mono_jiterp_get_trace_hit_count", "number", ["number"]],
123123
[true, "mono_jiterp_get_polling_required_address", "number", []],
124+
[true, "mono_jiterp_get_rejected_trace_count", "number", []],
124125
...legacy_interop_cwraps
125126
];
126127

@@ -238,6 +239,7 @@ export interface t_Cwraps {
238239
mono_jiterp_get_trace_hit_count(traceIndex: number): number;
239240
mono_jiterp_get_polling_required_address(): Int32Ptr;
240241
mono_jiterp_write_number_unaligned(destination: VoidPtr, value: number, mode: number): void;
242+
mono_jiterp_get_rejected_trace_count(): number;
241243
}
242244

243245
const wrapped_c_functions: t_Cwraps = <any>{};

src/mono/wasm/runtime/jiterpreter-support.ts

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1182,6 +1182,17 @@ class Cfg {
11821182
const disp = this.dispatchTable.get(segment.target)!;
11831183
if (this.trace)
11841184
console.log(`backward br from ${(<any>segment.from).toString(16)} to ${(<any>segment.target).toString(16)}: disp=${disp}`);
1185+
1186+
// set the backward branch taken flag in the cinfo so that the monitoring phase
1187+
// knows we took a backward branch. this is unfortunate but unavoidable overhead
1188+
// we just make it a flag instead of an increment to reduce the cost
1189+
this.builder.local("cinfo");
1190+
// TODO: Store the offset in opcodes instead? Probably not useful information
1191+
this.builder.i32_const(1);
1192+
this.builder.appendU8(WasmOpcode.i32_store);
1193+
this.builder.appendMemarg(0, 0); // JiterpreterCallInfo.backward_branch_taken
1194+
1195+
// set the dispatch index for the br_table
11851196
this.builder.i32_const(disp);
11861197
this.builder.local("disp", WasmOpcode.set_local);
11871198
} else {
@@ -1276,6 +1287,24 @@ export function append_bailout (builder: WasmBuilder, ip: MintOpcodePtr, reason:
12761287
builder.appendU8(WasmOpcode.return_);
12771288
}
12781289

1290+
// generate a bailout that is recorded for the monitoring phase as a possible early exit.
1291+
export function append_exit (builder: WasmBuilder, ip: MintOpcodePtr, opcodeCounter: number, reason: BailoutReason) {
1292+
if (opcodeCounter <= (builder.options.monitoringLongDistance + 1)) {
1293+
builder.local("cinfo");
1294+
builder.i32_const(opcodeCounter);
1295+
builder.appendU8(WasmOpcode.i32_store);
1296+
builder.appendMemarg(4, 0); // bailout_opcode_count
1297+
}
1298+
1299+
builder.ip_const(ip);
1300+
if (builder.options.countBailouts) {
1301+
builder.i32_const(builder.base);
1302+
builder.i32_const(reason);
1303+
builder.callImport("bailout");
1304+
}
1305+
builder.appendU8(WasmOpcode.return_);
1306+
}
1307+
12791308
export function copyIntoScratchBuffer (src: NativePointer, size: number) : NativePointer {
12801309
if (!scratchBuffer)
12811310
scratchBuffer = Module._malloc(64);
@@ -1551,6 +1580,10 @@ export type JiterpreterOptions = {
15511580
eliminateNullChecks: boolean;
15521581
minimumTraceLength: number;
15531582
minimumTraceHitCount: number;
1583+
monitoringPeriod: number;
1584+
monitoringShortDistance: number;
1585+
monitoringLongDistance: number;
1586+
monitoringMaxAveragePenalty: number;
15541587
jitCallHitCount: number;
15551588
jitCallFlushThreshold: number;
15561589
interpEntryHitCount: number;
@@ -1577,6 +1610,10 @@ const optionNames : { [jsName: string] : string } = {
15771610
"directJitCalls": "jiterpreter-direct-jit-calls",
15781611
"minimumTraceLength": "jiterpreter-minimum-trace-length",
15791612
"minimumTraceHitCount": "jiterpreter-minimum-trace-hit-count",
1613+
"monitoringPeriod": "jiterpreter-trace-monitoring-period",
1614+
"monitoringShortDistance": "jiterpreter-trace-monitoring-short-distance",
1615+
"monitoringLongDistance": "jiterpreter-trace-monitoring-long-distance",
1616+
"monitoringMaxAveragePenalty": "jiterpreter-trace-monitoring-max-average-penalty",
15801617
"jitCallHitCount": "jiterpreter-jit-call-hit-count",
15811618
"jitCallFlushThreshold": "jiterpreter-jit-call-queue-flush-threshold",
15821619
"interpEntryHitCount": "jiterpreter-interp-entry-hit-count",

0 commit comments

Comments
 (0)