Skip to content

Commit 64762d2

Browse files
Merge pull request #4319 from blueyed/harden-test_collect_init_tests
Fix handling of duplicate args with regard to Python packages
2 parents 176d274 + 827573c commit 64762d2

File tree

4 files changed

+91
-43
lines changed

4 files changed

+91
-43
lines changed

changelog/4310.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix duplicate collection due to multiple args matching the same packages.

src/_pytest/main.py

Lines changed: 34 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from _pytest.config import hookimpl
1919
from _pytest.config import UsageError
2020
from _pytest.outcomes import exit
21-
from _pytest.pathlib import parts
2221
from _pytest.runner import collect_one_node
2322

2423

@@ -387,6 +386,7 @@ def __init__(self, config):
387386
self._initialpaths = frozenset()
388387
# Keep track of any collected nodes in here, so we don't duplicate fixtures
389388
self._node_cache = {}
389+
self._pkg_roots = {}
390390

391391
self.config.pluginmanager.register(self, name="session")
392392

@@ -489,30 +489,26 @@ def _collect(self, arg):
489489

490490
names = self._parsearg(arg)
491491
argpath = names.pop(0).realpath()
492-
paths = set()
493492

494-
root = self
495493
# Start with a Session root, and delve to argpath item (dir or file)
496494
# and stack all Packages found on the way.
497495
# No point in finding packages when collecting doctests
498496
if not self.config.option.doctestmodules:
497+
pm = self.config.pluginmanager
499498
for parent in argpath.parts():
500-
pm = self.config.pluginmanager
501499
if pm._confcutdir and pm._confcutdir.relto(parent):
502500
continue
503501

504502
if parent.isdir():
505503
pkginit = parent.join("__init__.py")
506504
if pkginit.isfile():
507-
if pkginit in self._node_cache:
508-
root = self._node_cache[pkginit][0]
509-
else:
510-
col = root._collectfile(pkginit)
505+
if pkginit not in self._node_cache:
506+
col = self._collectfile(pkginit, handle_dupes=False)
511507
if col:
512508
if isinstance(col[0], Package):
513-
root = col[0]
509+
self._pkg_roots[parent] = col[0]
514510
# always store a list in the cache, matchnodes expects it
515-
self._node_cache[root.fspath] = [root]
511+
self._node_cache[col[0].fspath] = [col[0]]
516512

517513
# If it's a directory argument, recurse and look for any Subpackages.
518514
# Let the Package collector deal with subnodes, don't collect here.
@@ -535,28 +531,34 @@ def filter_(f):
535531
):
536532
dirpath = path.dirpath()
537533
if dirpath not in seen_dirs:
534+
# Collect packages first.
538535
seen_dirs.add(dirpath)
539536
pkginit = dirpath.join("__init__.py")
540-
if pkginit.exists() and parts(pkginit.strpath).isdisjoint(paths):
541-
for x in root._collectfile(pkginit):
542-
yield x
543-
paths.add(x.fspath.dirpath())
544-
545-
if parts(path.strpath).isdisjoint(paths):
546-
for x in root._collectfile(path):
547-
key = (type(x), x.fspath)
548-
if key in self._node_cache:
549-
yield self._node_cache[key]
550-
else:
551-
self._node_cache[key] = x
537+
if pkginit.exists():
538+
collect_root = self._pkg_roots.get(dirpath, self)
539+
for x in collect_root._collectfile(pkginit):
552540
yield x
541+
if isinstance(x, Package):
542+
self._pkg_roots[dirpath] = x
543+
if dirpath in self._pkg_roots:
544+
# Do not collect packages here.
545+
continue
546+
547+
for x in self._collectfile(path):
548+
key = (type(x), x.fspath)
549+
if key in self._node_cache:
550+
yield self._node_cache[key]
551+
else:
552+
self._node_cache[key] = x
553+
yield x
553554
else:
554555
assert argpath.check(file=1)
555556

556557
if argpath in self._node_cache:
557558
col = self._node_cache[argpath]
558559
else:
559-
col = root._collectfile(argpath)
560+
collect_root = self._pkg_roots.get(argpath.dirname, self)
561+
col = collect_root._collectfile(argpath)
560562
if col:
561563
self._node_cache[argpath] = col
562564
m = self.matchnodes(col, names)
@@ -570,20 +572,20 @@ def filter_(f):
570572
for y in m:
571573
yield y
572574

573-
def _collectfile(self, path):
575+
def _collectfile(self, path, handle_dupes=True):
574576
ihook = self.gethookproxy(path)
575577
if not self.isinitpath(path):
576578
if ihook.pytest_ignore_collect(path=path, config=self.config):
577579
return ()
578580

579-
# Skip duplicate paths.
580-
keepduplicates = self.config.getoption("keepduplicates")
581-
if not keepduplicates:
582-
duplicate_paths = self.config.pluginmanager._duplicatepaths
583-
if path in duplicate_paths:
584-
return ()
585-
else:
586-
duplicate_paths.add(path)
581+
if handle_dupes:
582+
keepduplicates = self.config.getoption("keepduplicates")
583+
if not keepduplicates:
584+
duplicate_paths = self.config.pluginmanager._duplicatepaths
585+
if path in duplicate_paths:
586+
return ()
587+
else:
588+
duplicate_paths.add(path)
587589

588590
return ihook.pytest_collect_file(path=path, parent=self)
589591

src/_pytest/python.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -545,11 +545,24 @@ def gethookproxy(self, fspath):
545545
proxy = self.config.hook
546546
return proxy
547547

548-
def _collectfile(self, path):
548+
def _collectfile(self, path, handle_dupes=True):
549549
ihook = self.gethookproxy(path)
550550
if not self.isinitpath(path):
551551
if ihook.pytest_ignore_collect(path=path, config=self.config):
552552
return ()
553+
554+
if handle_dupes:
555+
keepduplicates = self.config.getoption("keepduplicates")
556+
if not keepduplicates:
557+
duplicate_paths = self.config.pluginmanager._duplicatepaths
558+
if path in duplicate_paths:
559+
return ()
560+
else:
561+
duplicate_paths.add(path)
562+
563+
if self.fspath == path: # __init__.py
564+
return [self]
565+
553566
return ihook.pytest_collect_file(path=path, parent=self)
554567

555568
def isinitpath(self, path):

testing/test_collection.py

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -951,26 +951,58 @@ def test_collect_init_tests(testdir):
951951
result = testdir.runpytest(p, "--collect-only")
952952
result.stdout.fnmatch_lines(
953953
[
954-
"*<Module '__init__.py'>",
955-
"*<Function 'test_init'>",
956-
"*<Module 'test_foo.py'>",
957-
"*<Function 'test_foo'>",
954+
"collected 2 items",
955+
"<Package *",
956+
" <Module '__init__.py'>",
957+
" <Function 'test_init'>",
958+
" <Module 'test_foo.py'>",
959+
" <Function 'test_foo'>",
958960
]
959961
)
960962
result = testdir.runpytest("./tests", "--collect-only")
961963
result.stdout.fnmatch_lines(
962964
[
963-
"*<Module '__init__.py'>",
964-
"*<Function 'test_init'>",
965-
"*<Module 'test_foo.py'>",
966-
"*<Function 'test_foo'>",
965+
"collected 2 items",
966+
"<Package *",
967+
" <Module '__init__.py'>",
968+
" <Function 'test_init'>",
969+
" <Module 'test_foo.py'>",
970+
" <Function 'test_foo'>",
971+
]
972+
)
973+
# Ignores duplicates with "." and pkginit (#4310).
974+
result = testdir.runpytest("./tests", ".", "--collect-only")
975+
result.stdout.fnmatch_lines(
976+
[
977+
"collected 2 items",
978+
"<Package */tests'>",
979+
" <Module '__init__.py'>",
980+
" <Function 'test_init'>",
981+
" <Module 'test_foo.py'>",
982+
" <Function 'test_foo'>",
983+
]
984+
)
985+
# Same as before, but different order.
986+
result = testdir.runpytest(".", "tests", "--collect-only")
987+
result.stdout.fnmatch_lines(
988+
[
989+
"collected 2 items",
990+
"<Package */tests'>",
991+
" <Module '__init__.py'>",
992+
" <Function 'test_init'>",
993+
" <Module 'test_foo.py'>",
994+
" <Function 'test_foo'>",
967995
]
968996
)
969997
result = testdir.runpytest("./tests/test_foo.py", "--collect-only")
970-
result.stdout.fnmatch_lines(["*<Module 'test_foo.py'>", "*<Function 'test_foo'>"])
998+
result.stdout.fnmatch_lines(
999+
["<Package */tests'>", " <Module 'test_foo.py'>", " <Function 'test_foo'>"]
1000+
)
9711001
assert "test_init" not in result.stdout.str()
9721002
result = testdir.runpytest("./tests/__init__.py", "--collect-only")
973-
result.stdout.fnmatch_lines(["*<Module '__init__.py'>", "*<Function 'test_init'>"])
1003+
result.stdout.fnmatch_lines(
1004+
["<Package */tests'>", " <Module '__init__.py'>", " <Function 'test_init'>"]
1005+
)
9741006
assert "test_foo" not in result.stdout.str()
9751007

9761008

0 commit comments

Comments
 (0)