Skip to content

GH-111848: Set the IP when de-optimizing #112065

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Nov 15, 2023
104 changes: 52 additions & 52 deletions Include/internal/pycore_opcode_metadata.h

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions Include/internal/pycore_uops.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ extern "C" {
#define _Py_UOP_MAX_TRACE_LENGTH 128

typedef struct {
uint32_t opcode;
uint32_t oparg;
uint16_t opcode;
uint16_t oparg;
uint32_t target;
uint64_t operand; // A cache entry
} _PyUOpInstruction;

Expand Down
2 changes: 2 additions & 0 deletions Python/ceval.c
Original file line number Diff line number Diff line change
Expand Up @@ -1067,6 +1067,7 @@ _PyEval_EvalFrameDefault(PyThreadState *tstate, _PyInterpreterFrame *frame, int
UOP_STAT_INC(opcode, miss);
frame->return_offset = 0; // Dispatch to frame->instr_ptr
_PyFrame_SetStackPointer(frame, stack_pointer);
frame->instr_ptr = next_uop[-1].target + _PyCode_CODE((PyCodeObject *)frame->f_executable);
Py_DECREF(current_executor);
// Fall through
// Jump here from ENTER_EXECUTOR
Expand All @@ -1077,6 +1078,7 @@ _PyEval_EvalFrameDefault(PyThreadState *tstate, _PyInterpreterFrame *frame, int
// Jump here from _EXIT_TRACE
exit_trace:
_PyFrame_SetStackPointer(frame, stack_pointer);
frame->instr_ptr = next_uop[-1].target + _PyCode_CODE((PyCodeObject *)frame->f_executable);
Py_DECREF(current_executor);
OPT_HIST(trace_uop_execution_counter, trace_run_length_hist);
goto enter_tier_one;
Expand Down
48 changes: 22 additions & 26 deletions Python/optimizer.c
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,8 @@ translate_bytecode_to_trace(
#define DPRINTF(level, ...)
#endif

#define ADD_TO_TRACE(OPCODE, OPARG, OPERAND) \

#define ADD_TO_TRACE(OPCODE, OPARG, OPERAND, TARGET) \
DPRINTF(2, \
" ADD_TO_TRACE(%s, %d, %" PRIu64 ")\n", \
uop_name(OPCODE), \
Expand All @@ -458,23 +459,12 @@ translate_bytecode_to_trace(
trace[trace_length].opcode = (OPCODE); \
trace[trace_length].oparg = (OPARG); \
trace[trace_length].operand = (OPERAND); \
trace[trace_length].target = (TARGET); \
trace_length++;

#define INSTR_IP(INSTR, CODE) \
((uint32_t)((INSTR) - ((_Py_CODEUNIT *)(CODE)->co_code_adaptive)))

#define ADD_TO_STUB(INDEX, OPCODE, OPARG, OPERAND) \
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, no more stubs. Do you expect stubs to eventually make a comeback? If not, there are a few mentions of 'stub' that can be removed.

Notably the stub arg to RESERVE() is now always 0. We could remove it.

DPRINTF(2, " ADD_TO_STUB(%d, %s, %d, %" PRIu64 ")\n", \
(INDEX), \
uop_name(OPCODE), \
(OPARG), \
(uint64_t)(OPERAND)); \
assert(reserved > 0); \
reserved--; \
trace[(INDEX)].opcode = (OPCODE); \
trace[(INDEX)].oparg = (OPARG); \
trace[(INDEX)].operand = (OPERAND);

// Reserve space for n uops
#define RESERVE_RAW(n, opname) \
if (trace_length + (n) > max_length) { \
Expand All @@ -483,7 +473,7 @@ translate_bytecode_to_trace(
OPT_STAT_INC(trace_too_long); \
goto done; \
} \
reserved = (n); // Keep ADD_TO_TRACE / ADD_TO_STUB honest
reserved = (n); // Keep ADD_TO_TRACE honest

// Reserve space for main+stub uops, plus 3 for _SET_IP, _CHECK_VALIDITY and _EXIT_TRACE
#define RESERVE(main, stub) RESERVE_RAW((main) + (stub) + 3, uop_name(opcode))
Expand All @@ -493,7 +483,7 @@ translate_bytecode_to_trace(
if (trace_stack_depth >= TRACE_STACK_SIZE) { \
DPRINTF(2, "Trace stack overflow\n"); \
OPT_STAT_INC(trace_stack_overflow); \
ADD_TO_TRACE(_SET_IP, 0, 0); \
ADD_TO_TRACE(_EXIT_TRACE, 0, 0, 0); \
goto done; \
} \
trace_stack[trace_stack_depth].code = code; \
Expand All @@ -513,22 +503,28 @@ translate_bytecode_to_trace(
PyUnicode_AsUTF8(code->co_filename),
code->co_firstlineno,
2 * INSTR_IP(initial_instr, code));

uint32_t target = 0;
top: // Jump here after _PUSH_FRAME or likely branches
for (;;) {
target = INSTR_IP(instr, code);
RESERVE_RAW(3, "epilogue"); // Always need space for _SET_IP, _CHECK_VALIDITY and _EXIT_TRACE
ADD_TO_TRACE(_SET_IP, INSTR_IP(instr, code), 0);
ADD_TO_TRACE(_CHECK_VALIDITY, 0, 0);
ADD_TO_TRACE(_SET_IP, target, 0, target);
ADD_TO_TRACE(_CHECK_VALIDITY, 0, 0, target);

uint32_t opcode = instr->op.code;
uint32_t oparg = instr->op.arg;
uint32_t extras = 0;

while (opcode == EXTENDED_ARG) {

if (opcode == EXTENDED_ARG) {
instr++;
extras += 1;
opcode = instr->op.code;
oparg = (oparg << 8) | instr->op.arg;
if (opcode == EXTENDED_ARG) {
instr--;
goto done;
}
}

if (opcode == ENTER_EXECUTOR) {
Expand All @@ -554,7 +550,7 @@ translate_bytecode_to_trace(
DPRINTF(4, "%s(%d): counter=%x, bitcount=%d, likely=%d, uopcode=%s\n",
uop_name(opcode), oparg,
counter, bitcount, jump_likely, uop_name(uopcode));
ADD_TO_TRACE(uopcode, max_length, 0);
ADD_TO_TRACE(uopcode, max_length, 0, target);
if (jump_likely) {
_Py_CODEUNIT *target_instr = next_instr + oparg;
DPRINTF(2, "Jump likely (%x = %d bits), continue at byte offset %d\n",
Expand All @@ -569,7 +565,7 @@ translate_bytecode_to_trace(
{
if (instr + 2 - oparg == initial_instr && code == initial_code) {
RESERVE(1, 0);
ADD_TO_TRACE(_JUMP_TO_TOP, 0, 0);
ADD_TO_TRACE(_JUMP_TO_TOP, 0, 0, 0);
}
else {
OPT_STAT_INC(inner_loop);
Expand Down Expand Up @@ -653,7 +649,7 @@ translate_bytecode_to_trace(
expansion->uops[i].offset);
Py_FatalError("garbled expansion");
}
ADD_TO_TRACE(uop, oparg, operand);
ADD_TO_TRACE(uop, oparg, operand, target);
if (uop == _POP_FRAME) {
TRACE_STACK_POP();
DPRINTF(2,
Expand Down Expand Up @@ -682,15 +678,15 @@ translate_bytecode_to_trace(
PyUnicode_AsUTF8(new_code->co_filename),
new_code->co_firstlineno);
OPT_STAT_INC(recursive_call);
ADD_TO_TRACE(_SET_IP, 0, 0);
ADD_TO_TRACE(_EXIT_TRACE, 0, 0, 0);
goto done;
}
if (new_code->co_version != func_version) {
// func.__code__ was updated.
// Perhaps it may happen again, so don't bother tracing.
// TODO: Reason about this -- is it better to bail or not?
DPRINTF(2, "Bailing because co_version != func_version\n");
ADD_TO_TRACE(_SET_IP, 0, 0);
ADD_TO_TRACE(_EXIT_TRACE, 0, 0, 0);
goto done;
}
// Increment IP to the return address
Expand All @@ -707,7 +703,7 @@ translate_bytecode_to_trace(
2 * INSTR_IP(instr, code));
goto top;
}
ADD_TO_TRACE(_SET_IP, 0, 0);
ADD_TO_TRACE(_EXIT_TRACE, 0, 0, 0);
goto done;
}
}
Expand All @@ -732,7 +728,7 @@ translate_bytecode_to_trace(
assert(code == initial_code);
// Skip short traces like _SET_IP, LOAD_FAST, _SET_IP, _EXIT_TRACE
if (trace_length > 4) {
ADD_TO_TRACE(_EXIT_TRACE, 0, 0);
ADD_TO_TRACE(_EXIT_TRACE, 0, 0, target);
DPRINTF(1,
"Created a trace for %s (%s:%d) at byte offset %d -- length %d+%d\n",
PyUnicode_AsUTF8(code->co_qualname),
Expand Down
22 changes: 10 additions & 12 deletions Python/optimizer_analysis.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,15 @@ remove_unneeded_uops(_PyUOpInstruction *buffer, int buffer_size)
{
// Note that we don't enter stubs, those SET_IPs are needed.
int last_set_ip = -1;
bool need_ip = true;
bool maybe_invalid = false;
for (int pc = 0; pc < buffer_size; pc++) {
int opcode = buffer[pc].opcode;
if (opcode == _SET_IP) {
if (!need_ip && last_set_ip >= 0) {
buffer[last_set_ip].opcode = NOP;
}
need_ip = false;
buffer[pc].opcode = NOP;
last_set_ip = pc;
}
else if (opcode == _CHECK_VALIDITY) {
if (maybe_invalid) {
/* Exiting the trace requires that IP is correct */
need_ip = true;
maybe_invalid = false;
}
else {
Expand All @@ -42,12 +36,16 @@ remove_unneeded_uops(_PyUOpInstruction *buffer, int buffer_size)
break;
}
else {
// If opcode has ERROR or DEOPT, set need_ip to true
if (_PyOpcode_opcode_metadata[opcode].flags & (HAS_ERROR_FLAG | HAS_DEOPT_FLAG) || opcode == _PUSH_FRAME) {
need_ip = true;
}
if (_PyOpcode_opcode_metadata[opcode].flags & HAS_ESCAPES_FLAG) {
if (OPCODE_HAS_ESCAPES(opcode)) {
maybe_invalid = true;
if (last_set_ip >= 0) {
buffer[last_set_ip].opcode = _SET_IP;
}
}
if (OPCODE_HAS_ERROR(opcode) || opcode == _PUSH_FRAME) {
if (last_set_ip >= 0) {
buffer[last_set_ip].opcode = _SET_IP;
}
}
}
}
Expand Down
35 changes: 29 additions & 6 deletions Tools/cases_generator/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import parsing
from typing import AbstractSet

WHITELIST = (
NON_ESCAPING_FUNCTIONS = (
"Py_INCREF",
"_PyDictOrValues_IsValues",
"_PyObject_DictOrValuesPointer",
Expand All @@ -31,9 +31,29 @@
"_PyLong_IsNonNegativeCompact",
"_PyLong_CompactValue",
"_Py_NewRef",
"_Py_IsImmortal",
"_Py_STR",
"_PyLong_Add",
"_PyLong_Multiply",
"_PyLong_Subtract",
"Py_NewRef",
"_PyList_ITEMS",
"_PyTuple_ITEMS",
"_PyList_AppendTakeRef",
"_Py_atomic_load_uintptr_relaxed",
"_PyFrame_GetCode",
"_PyThreadState_HasStackSpace",
)

def makes_escaping_api_call(instr: parsing.Node) -> bool:
ESCAPING_FUNCTIONS = (
"import_name",
"import_from",
)


def makes_escaping_api_call(instr: parsing.InstDef) -> bool:
if "CALL_INTRINSIC" in instr.name:
return True;
tkns = iter(instr.tokens)
for tkn in tkns:
if tkn.kind != lx.IDENTIFIER:
Expand All @@ -44,13 +64,17 @@ def makes_escaping_api_call(instr: parsing.Node) -> bool:
return False
if next_tkn.kind != lx.LPAREN:
continue
if tkn.text in ESCAPING_FUNCTIONS:
return True
if not tkn.text.startswith("Py") and not tkn.text.startswith("_Py"):
continue
if tkn.text.endswith("Check"):
continue
if tkn.text.startswith("Py_Is"):
continue
if tkn.text.endswith("CheckExact"):
continue
if tkn.text in WHITELIST:
if tkn.text in NON_ESCAPING_FUNCTIONS:
continue
return True
return False
Expand All @@ -74,7 +98,7 @@ def __post_init__(self) -> None:
self.bitmask = {name: (1 << i) for i, name in enumerate(self.names())}

@staticmethod
def fromInstruction(instr: parsing.Node) -> "InstructionFlags":
def fromInstruction(instr: parsing.InstDef) -> "InstructionFlags":
has_free = (
variable_used(instr, "PyCell_New")
or variable_used(instr, "PyCell_GET")
Expand All @@ -101,8 +125,7 @@ def fromInstruction(instr: parsing.Node) -> "InstructionFlags":
or variable_used(instr, "resume_with_error")
),
HAS_ESCAPES_FLAG=(
variable_used(instr, "tstate")
or makes_escaping_api_call(instr)
makes_escaping_api_call(instr)
),
)

Expand Down