Skip to content

gh-126835: Move constant unaryop & binop folding to CFG #129550

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 35 commits into from
Feb 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
04b4a84
move binop folding to cfg
WolframAlph Feb 14, 2025
a6babab
update tests
WolframAlph Feb 14, 2025
0fa7c4c
add match case folding tests
WolframAlph Feb 14, 2025
27b1a09
add peepholer tests
WolframAlph Feb 15, 2025
d3df8c8
polish ast code
WolframAlph Feb 15, 2025
e82841a
add assertions
WolframAlph Feb 15, 2025
cd6fbd7
use IS_NUMERIC_CONST_EXPR
WolframAlph Feb 15, 2025
300976e
bring back tests
WolframAlph Feb 15, 2025
3240d8e
move unaryop to cfg
WolframAlph Feb 15, 2025
74d2275
add tests
WolframAlph Feb 15, 2025
8358e28
cancel out unary not
WolframAlph Feb 15, 2025
c1a1be5
fix optimize_if_const_unaryop
WolframAlph Feb 15, 2025
ef9221a
add test_optimize_unary_not
WolframAlph Feb 15, 2025
598adce
add optimize_unary_not_non_const
WolframAlph Feb 16, 2025
bb60d4c
add tests
WolframAlph Feb 16, 2025
eb99870
add tests
WolframAlph Feb 16, 2025
c93627e
add static
WolframAlph Feb 16, 2025
f0f044a
polish
WolframAlph Feb 16, 2025
a02ac66
add unaryop folding tests
WolframAlph Feb 16, 2025
2739fa0
add not not test to catch misoptimized case
WolframAlph Feb 17, 2025
258a5b6
revert old unarynot handing, add contains/is + unarynot folding
WolframAlph Feb 17, 2025
59ee897
address reviews
WolframAlph Feb 17, 2025
91ea2fa
address reviews
WolframAlph Feb 17, 2025
2c5ee86
simplify optimize_if_const_unaryop
WolframAlph Feb 17, 2025
0399fce
address reviews
WolframAlph Feb 19, 2025
3c53923
simplify instr_make_load_const
WolframAlph Feb 19, 2025
a6a06c8
address review for tests
WolframAlph Feb 20, 2025
2ea8425
replace macros
WolframAlph Feb 20, 2025
7c9c69b
update peepholer test
WolframAlph Feb 20, 2025
099afba
define is_unarynegative_const_complex_expr & is_allowed_match_case_bi…
WolframAlph Feb 20, 2025
0dad7c1
try to fold match case expression without checks
WolframAlph Feb 20, 2025
1e8b552
simplify folding
WolframAlph Feb 20, 2025
f4e9a42
address review
WolframAlph Feb 20, 2025
2dba23f
minor adjustments
WolframAlph Feb 21, 2025
cafbc61
Merge branch 'main' into fold-binop-cfg
iritkatriel Feb 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 80 additions & 118 deletions Lib/test/test_ast/test_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,18 +154,17 @@ def test_optimization_levels__debug__(self):
self.assertEqual(res.body[0].value.id, expected)

def test_optimization_levels_const_folding(self):
folded = ('Expr', (1, 0, 1, 5), ('Constant', (1, 0, 1, 5), 3, None))
not_folded = ('Expr', (1, 0, 1, 5),
('BinOp', (1, 0, 1, 5),
('Constant', (1, 0, 1, 1), 1, None),
('Add',),
('Constant', (1, 4, 1, 5), 2, None)))
folded = ('Expr', (1, 0, 1, 6), ('Constant', (1, 0, 1, 6), (1, 2), None))
not_folded = ('Expr', (1, 0, 1, 6),
('Tuple', (1, 0, 1, 6),
[('Constant', (1, 1, 1, 2), 1, None),
('Constant', (1, 4, 1, 5), 2, None)], ('Load',)))

cases = [(-1, not_folded), (0, not_folded), (1, folded), (2, folded)]
for (optval, expected) in cases:
with self.subTest(optval=optval):
tree1 = ast.parse("1 + 2", optimize=optval)
tree2 = ast.parse(ast.parse("1 + 2"), optimize=optval)
tree1 = ast.parse("(1, 2)", optimize=optval)
tree2 = ast.parse(ast.parse("(1, 2)"), optimize=optval)
for tree in [tree1, tree2]:
res = to_tuple(tree.body[0])
self.assertEqual(res, expected)
Expand Down Expand Up @@ -3089,27 +3088,6 @@ def test_cli_file_input(self):


class ASTOptimiziationTests(unittest.TestCase):
binop = {
"+": ast.Add(),
"-": ast.Sub(),
"*": ast.Mult(),
"/": ast.Div(),
"%": ast.Mod(),
"<<": ast.LShift(),
">>": ast.RShift(),
"|": ast.BitOr(),
"^": ast.BitXor(),
"&": ast.BitAnd(),
"//": ast.FloorDiv(),
"**": ast.Pow(),
}

unaryop = {
"~": ast.Invert(),
"+": ast.UAdd(),
"-": ast.USub(),
}

def wrap_expr(self, expr):
return ast.Module(body=[ast.Expr(value=expr)])

Expand Down Expand Up @@ -3141,83 +3119,6 @@ def assert_ast(self, code, non_optimized_target, optimized_target):
f"{ast.dump(optimized_tree)}",
)

def create_binop(self, operand, left=ast.Constant(1), right=ast.Constant(1)):
return ast.BinOp(left=left, op=self.binop[operand], right=right)

def test_folding_binop(self):
code = "1 %s 1"
operators = self.binop.keys()

for op in operators:
result_code = code % op
non_optimized_target = self.wrap_expr(self.create_binop(op))
optimized_target = self.wrap_expr(ast.Constant(value=eval(result_code)))

with self.subTest(
result_code=result_code,
non_optimized_target=non_optimized_target,
optimized_target=optimized_target
):
self.assert_ast(result_code, non_optimized_target, optimized_target)

# Multiplication of constant tuples must be folded
code = "(1,) * 3"
non_optimized_target = self.wrap_expr(self.create_binop("*", ast.Tuple(elts=[ast.Constant(value=1)]), ast.Constant(value=3)))
optimized_target = self.wrap_expr(ast.Constant(eval(code)))

self.assert_ast(code, non_optimized_target, optimized_target)

def test_folding_unaryop(self):
code = "%s1"
operators = self.unaryop.keys()

def create_unaryop(operand):
return ast.UnaryOp(op=self.unaryop[operand], operand=ast.Constant(1))

for op in operators:
result_code = code % op
non_optimized_target = self.wrap_expr(create_unaryop(op))
optimized_target = self.wrap_expr(ast.Constant(eval(result_code)))

with self.subTest(
result_code=result_code,
non_optimized_target=non_optimized_target,
optimized_target=optimized_target
):
self.assert_ast(result_code, non_optimized_target, optimized_target)

def test_folding_not(self):
code = "not (1 %s (1,))"
operators = {
"in": ast.In(),
"is": ast.Is(),
}
opt_operators = {
"is": ast.IsNot(),
"in": ast.NotIn(),
}

def create_notop(operand):
return ast.UnaryOp(op=ast.Not(), operand=ast.Compare(
left=ast.Constant(value=1),
ops=[operators[operand]],
comparators=[ast.Tuple(elts=[ast.Constant(value=1)])]
))

for op in operators.keys():
result_code = code % op
non_optimized_target = self.wrap_expr(create_notop(op))
optimized_target = self.wrap_expr(
ast.Compare(left=ast.Constant(1), ops=[opt_operators[op]], comparators=[ast.Constant(value=(1,))])
)

with self.subTest(
result_code=result_code,
non_optimized_target=non_optimized_target,
optimized_target=optimized_target
):
self.assert_ast(result_code, non_optimized_target, optimized_target)

def test_folding_format(self):
code = "'%s' % (a,)"

Expand Down Expand Up @@ -3247,9 +3148,9 @@ def test_folding_tuple(self):
self.assert_ast(code, non_optimized_target, optimized_target)

def test_folding_type_param_in_function_def(self):
code = "def foo[%s = 1 + 1](): pass"
code = "def foo[%s = (1, 2)](): pass"

unoptimized_binop = self.create_binop("+")
unoptimized_tuple = ast.Tuple(elts=[ast.Constant(1), ast.Constant(2)])
unoptimized_type_params = [
("T", "T", ast.TypeVar),
("**P", "P", ast.ParamSpec),
Expand All @@ -3263,23 +3164,23 @@ def test_folding_type_param_in_function_def(self):
name='foo',
args=ast.arguments(),
body=[ast.Pass()],
type_params=[type_param(name=name, default_value=ast.Constant(2))]
type_params=[type_param(name=name, default_value=ast.Constant((1, 2)))]
)
)
non_optimized_target = self.wrap_statement(
ast.FunctionDef(
name='foo',
args=ast.arguments(),
body=[ast.Pass()],
type_params=[type_param(name=name, default_value=unoptimized_binop)]
type_params=[type_param(name=name, default_value=unoptimized_tuple)]
)
)
self.assert_ast(result_code, non_optimized_target, optimized_target)

def test_folding_type_param_in_class_def(self):
code = "class foo[%s = 1 + 1]: pass"
code = "class foo[%s = (1, 2)]: pass"

unoptimized_binop = self.create_binop("+")
unoptimized_tuple = ast.Tuple(elts=[ast.Constant(1), ast.Constant(2)])
unoptimized_type_params = [
("T", "T", ast.TypeVar),
("**P", "P", ast.ParamSpec),
Expand All @@ -3292,22 +3193,22 @@ def test_folding_type_param_in_class_def(self):
ast.ClassDef(
name='foo',
body=[ast.Pass()],
type_params=[type_param(name=name, default_value=ast.Constant(2))]
type_params=[type_param(name=name, default_value=ast.Constant((1, 2)))]
)
)
non_optimized_target = self.wrap_statement(
ast.ClassDef(
name='foo',
body=[ast.Pass()],
type_params=[type_param(name=name, default_value=unoptimized_binop)]
type_params=[type_param(name=name, default_value=unoptimized_tuple)]
)
)
self.assert_ast(result_code, non_optimized_target, optimized_target)

def test_folding_type_param_in_type_alias(self):
code = "type foo[%s = 1 + 1] = 1"
code = "type foo[%s = (1, 2)] = 1"

unoptimized_binop = self.create_binop("+")
unoptimized_tuple = ast.Tuple(elts=[ast.Constant(1), ast.Constant(2)])
unoptimized_type_params = [
("T", "T", ast.TypeVar),
("**P", "P", ast.ParamSpec),
Expand All @@ -3319,19 +3220,80 @@ def test_folding_type_param_in_type_alias(self):
optimized_target = self.wrap_statement(
ast.TypeAlias(
name=ast.Name(id='foo', ctx=ast.Store()),
type_params=[type_param(name=name, default_value=ast.Constant(2))],
type_params=[type_param(name=name, default_value=ast.Constant((1, 2)))],
value=ast.Constant(value=1),
)
)
non_optimized_target = self.wrap_statement(
ast.TypeAlias(
name=ast.Name(id='foo', ctx=ast.Store()),
type_params=[type_param(name=name, default_value=unoptimized_binop)],
type_params=[type_param(name=name, default_value=unoptimized_tuple)],
value=ast.Constant(value=1),
)
)
self.assert_ast(result_code, non_optimized_target, optimized_target)

def test_folding_match_case_allowed_expressions(self):
def get_match_case_values(node):
result = []
if isinstance(node, ast.Constant):
result.append(node.value)
elif isinstance(node, ast.MatchValue):
result.extend(get_match_case_values(node.value))
elif isinstance(node, ast.MatchMapping):
for key in node.keys:
result.extend(get_match_case_values(key))
elif isinstance(node, ast.MatchSequence):
for pat in node.patterns:
result.extend(get_match_case_values(pat))
else:
self.fail(f"Unexpected node {node}")
return result

tests = [
("-0", [0]),
("-0.1", [-0.1]),
("-0j", [complex(0, 0)]),
("-0.1j", [complex(0, -0.1)]),
("1 + 2j", [complex(1, 2)]),
("1 - 2j", [complex(1, -2)]),
("1.1 + 2.1j", [complex(1.1, 2.1)]),
("1.1 - 2.1j", [complex(1.1, -2.1)]),
("-0 + 1j", [complex(0, 1)]),
("-0 - 1j", [complex(0, -1)]),
("-0.1 + 1.1j", [complex(-0.1, 1.1)]),
("-0.1 - 1.1j", [complex(-0.1, -1.1)]),
("{-0: 0}", [0]),
("{-0.1: 0}", [-0.1]),
("{-0j: 0}", [complex(0, 0)]),
("{-0.1j: 0}", [complex(0, -0.1)]),
("{1 + 2j: 0}", [complex(1, 2)]),
("{1 - 2j: 0}", [complex(1, -2)]),
("{1.1 + 2.1j: 0}", [complex(1.1, 2.1)]),
("{1.1 - 2.1j: 0}", [complex(1.1, -2.1)]),
("{-0 + 1j: 0}", [complex(0, 1)]),
("{-0 - 1j: 0}", [complex(0, -1)]),
("{-0.1 + 1.1j: 0}", [complex(-0.1, 1.1)]),
("{-0.1 - 1.1j: 0}", [complex(-0.1, -1.1)]),
("{-0: 0, 0 + 1j: 0, 0.1 + 1j: 0}", [0, complex(0, 1), complex(0.1, 1)]),
("[-0, -0.1, -0j, -0.1j]", [0, -0.1, complex(0, 0), complex(0, -0.1)]),
("[[[[-0, -0.1, -0j, -0.1j]]]]", [0, -0.1, complex(0, 0), complex(0, -0.1)]),
("[[-0, -0.1], -0j, -0.1j]", [0, -0.1, complex(0, 0), complex(0, -0.1)]),
("[[-0, -0.1], [-0j, -0.1j]]", [0, -0.1, complex(0, 0), complex(0, -0.1)]),
("(-0, -0.1, -0j, -0.1j)", [0, -0.1, complex(0, 0), complex(0, -0.1)]),
("((((-0, -0.1, -0j, -0.1j))))", [0, -0.1, complex(0, 0), complex(0, -0.1)]),
("((-0, -0.1), -0j, -0.1j)", [0, -0.1, complex(0, 0), complex(0, -0.1)]),
("((-0, -0.1), (-0j, -0.1j))", [0, -0.1, complex(0, 0), complex(0, -0.1)]),
]
for match_expr, constants in tests:
with self.subTest(match_expr):
src = f"match 0:\n\t case {match_expr}: pass"
tree = ast.parse(src, optimize=1)
match_stmt = tree.body[0]
case = match_stmt.cases[0]
values = get_match_case_values(case.pattern)
self.assertListEqual(constants, values)


if __name__ == '__main__':
if len(sys.argv) > 1 and sys.argv[1] == '--snapshot-update':
Expand Down
2 changes: 1 addition & 1 deletion Lib/test/test_ast/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
def to_tuple(t):
if t is None or isinstance(t, (str, int, complex, float, bytes)) or t is Ellipsis:
if t is None or isinstance(t, (str, int, complex, float, bytes, tuple)) or t is Ellipsis:
return t
elif isinstance(t, list):
return [to_tuple(e) for e in t]
Expand Down
15 changes: 6 additions & 9 deletions Lib/test/test_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ def test_compile_async_generator(self):
self.assertEqual(type(glob['ticker']()), AsyncGeneratorType)

def test_compile_ast(self):
args = ("a*(1+2)", "f.py", "exec")
args = ("a*(1,2)", "f.py", "exec")
raw = compile(*args, flags = ast.PyCF_ONLY_AST).body[0]
opt1 = compile(*args, flags = ast.PyCF_OPTIMIZED_AST).body[0]
opt2 = compile(ast.parse(args[0]), *args[1:], flags = ast.PyCF_OPTIMIZED_AST).body[0]
Expand All @@ -566,17 +566,14 @@ def test_compile_ast(self):
self.assertIsInstance(tree.value.left, ast.Name)
self.assertEqual(tree.value.left.id, 'a')

raw_right = raw.value.right # expect BinOp(1, '+', 2)
self.assertIsInstance(raw_right, ast.BinOp)
self.assertIsInstance(raw_right.left, ast.Constant)
self.assertEqual(raw_right.left.value, 1)
self.assertIsInstance(raw_right.right, ast.Constant)
self.assertEqual(raw_right.right.value, 2)
raw_right = raw.value.right # expect Tuple((1, 2))
self.assertIsInstance(raw_right, ast.Tuple)
self.assertListEqual([elt.value for elt in raw_right.elts], [1, 2])

for opt in [opt1, opt2]:
opt_right = opt.value.right # expect Constant(3)
opt_right = opt.value.right # expect Constant((1,2))
self.assertIsInstance(opt_right, ast.Constant)
self.assertEqual(opt_right.value, 3)
self.assertEqual(opt_right.value, (1, 2))

def test_delattr(self):
sys.spam = 1
Expand Down
Loading
Loading