Skip to content

Commit 9e488c3

Browse files
committed
Fix all() unroll for non-generators/non-list comprehensions
Fix #5358
1 parent e4fe41e commit 9e488c3

File tree

3 files changed

+24
-6
lines changed

3 files changed

+24
-6
lines changed

changelog/5358.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix assertion rewriting of ``all()`` calls to deal with non-generators.

src/_pytest/assertion/rewrite.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -949,11 +949,19 @@ def visit_BinOp(self, binop):
949949
res = self.assign(ast.BinOp(left_expr, binop.op, right_expr))
950950
return res, explanation
951951

952+
def _is_any_call_with_generator_or_list_comprehension(self, call):
953+
"""Return True if the Call node is an 'any' call with a generator or list comprehension"""
954+
return (
955+
isinstance(call.func, ast.Name)
956+
and call.func.id == "all"
957+
and isinstance(call.args[0], (ast.GeneratorExp, ast.ListComp))
958+
)
959+
952960
def visit_Call_35(self, call):
953961
"""
954962
visit `ast.Call` nodes on Python3.5 and after
955963
"""
956-
if isinstance(call.func, ast.Name) and call.func.id == "all":
964+
if self._is_any_call_with_generator_or_list_comprehension(call):
957965
return self._visit_all(call)
958966
new_func, func_expl = self.visit(call.func)
959967
arg_expls = []
@@ -980,8 +988,6 @@ def visit_Call_35(self, call):
980988

981989
def _visit_all(self, call):
982990
"""Special rewrite for the builtin all function, see #5062"""
983-
if not isinstance(call.args[0], (ast.GeneratorExp, ast.ListComp)):
984-
return
985991
gen_exp = call.args[0]
986992
assertion_module = ast.Module(
987993
body=[ast.Assert(test=gen_exp.elt, lineno=1, msg="", col_offset=1)]
@@ -1009,7 +1015,7 @@ def visit_Call_legacy(self, call):
10091015
"""
10101016
visit `ast.Call nodes on 3.4 and below`
10111017
"""
1012-
if isinstance(call.func, ast.Name) and call.func.id == "all":
1018+
if self._is_any_call_with_generator_or_list_comprehension(call):
10131019
return self._visit_all(call)
10141020
new_func, func_expl = self.visit(call.func)
10151021
arg_expls = []

testing/test_assertrewrite.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -677,7 +677,7 @@ def __repr__(self):
677677
assert "UnicodeDecodeError" not in msg
678678
assert "UnicodeEncodeError" not in msg
679679

680-
def test_unroll_generator(self, testdir):
680+
def test_unroll_all_generator(self, testdir):
681681
testdir.makepyfile(
682682
"""
683683
def check_even(num):
@@ -692,7 +692,7 @@ def test_generator():
692692
result = testdir.runpytest()
693693
result.stdout.fnmatch_lines(["*assert False*", "*where False = check_even(1)*"])
694694

695-
def test_unroll_list_comprehension(self, testdir):
695+
def test_unroll_all_list_comprehension(self, testdir):
696696
testdir.makepyfile(
697697
"""
698698
def check_even(num):
@@ -707,6 +707,17 @@ def test_list_comprehension():
707707
result = testdir.runpytest()
708708
result.stdout.fnmatch_lines(["*assert False*", "*where False = check_even(1)*"])
709709

710+
def test_unroll_all_object(self, testdir):
711+
"""all() for non generators/non list-comprehensions (#5358)"""
712+
testdir.makepyfile(
713+
"""
714+
def test():
715+
assert all((1, 0))
716+
"""
717+
)
718+
result = testdir.runpytest()
719+
result.stdout.fnmatch_lines(["*assert False*", "*where False = all((1, 0))*"])
720+
710721
def test_for_loop(self, testdir):
711722
testdir.makepyfile(
712723
"""

0 commit comments

Comments
 (0)