Skip to content

Commit 7c823af

Browse files
mimre25cooperlees
andauthored
feat(rules): Add rule to check for mutations of loop iterator (#446)
* feat(rules): Add rule to check for mutations of loop iterator * fixup! feat(rules): Add rule to check for mutations of loop iterator * doc(B038): Add doc for B038 to README.rst --------- Co-authored-by: Cooper Lees <[email protected]>
1 parent 6c96f75 commit 7c823af

File tree

4 files changed

+129
-0
lines changed

4 files changed

+129
-0
lines changed

README.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,10 @@ second usage. Save the result to a list if the result is needed multiple times.
198198
**B036**: Found ``except BaseException:`` without re-raising (no ``raise`` in the top-level of the ``except`` block). This catches all kinds of things (Exception, SystemExit, KeyboardInterrupt...) and may prevent a program from exiting as expected.
199199

200200
**B037**: Found ``return <value>``, ``yield``, ``yield <value>``, or ``yield from <value>`` in class ``__init__()`` method. No values should be returned or yielded, only bare ``return``s are ok.
201+
202+
**B038**: Found a mutation of a mutable loop iterable inside the loop body.
203+
Changes to the iterable of a loop such as calls to `list.remove()` or via `del` can cause unintended bugs.
204+
201205
Opinionated warnings
202206
~~~~~~~~~~~~~~~~~~~~
203207

bugbear.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ def _flatten_excepthandler(node):
237237
):
238238
expr_list.extend(expr.value.elts)
239239
continue
240+
240241
yield expr
241242

242243

@@ -521,6 +522,7 @@ def visit_For(self, node):
521522
self.check_for_b020(node)
522523
self.check_for_b023(node)
523524
self.check_for_b031(node)
525+
self.check_for_b038(node)
524526
self.generic_visit(node)
525527

526528
def visit_AsyncFor(self, node):
@@ -1570,6 +1572,18 @@ def check(num_args, param_name):
15701572
elif node.func.attr == "split":
15711573
check(2, "maxsplit")
15721574

1575+
def check_for_b038(self, node: ast.For):
1576+
if isinstance(node.iter, ast.Name):
1577+
name = _to_name_str(node.iter)
1578+
elif isinstance(node.iter, ast.Attribute):
1579+
name = _to_name_str(node.iter)
1580+
else:
1581+
return
1582+
checker = B038Checker(name)
1583+
checker.visit(node.body)
1584+
for mutation in checker.mutations:
1585+
self.errors.append(B038(mutation.lineno, mutation.col_offset))
1586+
15731587

15741588
def compose_call_path(node):
15751589
if isinstance(node, ast.Attribute):
@@ -1581,6 +1595,49 @@ def compose_call_path(node):
15811595
yield node.id
15821596

15831597

1598+
class B038Checker(ast.NodeVisitor):
1599+
def __init__(self, name: str):
1600+
self.name = name
1601+
self.mutations = []
1602+
1603+
def visit_Delete(self, node: ast.Delete):
1604+
for target in node.targets:
1605+
if isinstance(target, ast.Subscript):
1606+
name = _to_name_str(target.value)
1607+
elif isinstance(target, (ast.Attribute, ast.Name)):
1608+
name = _to_name_str(target)
1609+
else:
1610+
name = "" # fallback
1611+
self.generic_visit(target)
1612+
1613+
if name == self.name:
1614+
self.mutations.append(node)
1615+
1616+
def visit_Call(self, node: ast.Call):
1617+
if isinstance(node.func, ast.Attribute):
1618+
name = _to_name_str(node.func.value)
1619+
function_object = name
1620+
1621+
if function_object == self.name:
1622+
self.mutations.append(node)
1623+
1624+
self.generic_visit(node)
1625+
1626+
def visit_Name(self, node: ast.Name):
1627+
if isinstance(node.ctx, ast.Del):
1628+
self.mutations.append(node)
1629+
self.generic_visit(node)
1630+
1631+
def visit(self, node):
1632+
"""Like super-visit but supports iteration over lists."""
1633+
if not isinstance(node, list):
1634+
return super().visit(node)
1635+
1636+
for elem in node:
1637+
super().visit(elem)
1638+
return node
1639+
1640+
15841641
@attr.s
15851642
class NameFinder(ast.NodeVisitor):
15861643
"""Finds a name within a tree of nodes.
@@ -2075,6 +2132,12 @@ def visit_Lambda(self, node):
20752132
" statement."
20762133
)
20772134
)
2135+
20782136
B950 = Error(message="B950 line too long ({} > {} characters)")
20792137

2138+
B038 = Error(
2139+
message=(
2140+
"B038 editing a loop's mutable iterable often leads to unexpected results/bugs"
2141+
)
2142+
)
20802143
disabled_by_default = ["B901", "B902", "B903", "B904", "B905", "B906", "B908", "B950"]

tests/b038.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
"""
2+
Should emit:
3+
B999 - on lines 11, 25, 26, 40, 46
4+
"""
5+
6+
7+
some_list = [1, 2, 3]
8+
for elem in some_list:
9+
print(elem)
10+
if elem % 2 == 0:
11+
some_list.remove(elem) # should error
12+
13+
some_list = [1, 2, 3]
14+
some_other_list = [1, 2, 3]
15+
for elem in some_list:
16+
print(elem)
17+
if elem % 2 == 0:
18+
some_other_list.remove(elem) # should not error
19+
20+
21+
some_list = [1, 2, 3]
22+
for elem in some_list:
23+
print(elem)
24+
if elem % 2 == 0:
25+
del some_list[2] # should error
26+
del some_list
27+
28+
29+
class A:
30+
some_list: list
31+
32+
def __init__(self, ls):
33+
self.some_list = list(ls)
34+
35+
36+
a = A((1, 2, 3))
37+
for elem in a.some_list:
38+
print(elem)
39+
if elem % 2 == 0:
40+
a.some_list.remove(elem) # should error
41+
42+
a = A((1, 2, 3))
43+
for elem in a.some_list:
44+
print(elem)
45+
if elem % 2 == 0:
46+
del a.some_list[2] # should error

tests/test_bugbear.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
B035,
4747
B036,
4848
B037,
49+
B038,
4950
B901,
5051
B902,
5152
B903,
@@ -967,6 +968,21 @@ def test_selfclean_test_bugbear(self):
967968
self.assertEqual(proc.stdout, b"")
968969
self.assertEqual(proc.stderr, b"")
969970

971+
def test_b038(self):
972+
filename = Path(__file__).absolute().parent / "b038.py"
973+
mock_options = Namespace(select=[], extend_select=["B038"])
974+
bbc = BugBearChecker(filename=str(filename), options=mock_options)
975+
errors = list(bbc.run())
976+
print(errors)
977+
expected = [
978+
B038(11, 8),
979+
B038(25, 8),
980+
B038(26, 8),
981+
B038(40, 8),
982+
B038(46, 8),
983+
]
984+
self.assertEqual(errors, self.errors(*expected))
985+
970986

971987
class TestFuzz(unittest.TestCase):
972988
from hypothesis import HealthCheck, given, settings

0 commit comments

Comments
 (0)