Skip to content

Commit 2bd0d97

Browse files
authored
Merge pull request #7601 from bluetech/typing-longrepr
typing: resultlog, pytester, longrepr
2 parents 303030c + fbf251f commit 2bd0d97

File tree

8 files changed

+216
-103
lines changed

8 files changed

+216
-103
lines changed

src/_pytest/junitxml.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from _pytest import nodes
2626
from _pytest import timing
2727
from _pytest._code.code import ExceptionRepr
28+
from _pytest._code.code import ReprFileLocation
2829
from _pytest.config import Config
2930
from _pytest.config import filename_arg
3031
from _pytest.config.argparsing import Parser
@@ -200,8 +201,11 @@ def append_failure(self, report: TestReport) -> None:
200201
self._add_simple("skipped", "xfail-marked test passes unexpectedly")
201202
else:
202203
assert report.longrepr is not None
203-
if getattr(report.longrepr, "reprcrash", None) is not None:
204-
message = report.longrepr.reprcrash.message
204+
reprcrash = getattr(
205+
report.longrepr, "reprcrash", None
206+
) # type: Optional[ReprFileLocation]
207+
if reprcrash is not None:
208+
message = reprcrash.message
205209
else:
206210
message = str(report.longrepr)
207211
message = bin_xml_escape(message)
@@ -217,8 +221,11 @@ def append_collect_skipped(self, report: TestReport) -> None:
217221

218222
def append_error(self, report: TestReport) -> None:
219223
assert report.longrepr is not None
220-
if getattr(report.longrepr, "reprcrash", None) is not None:
221-
reason = report.longrepr.reprcrash.message
224+
reprcrash = getattr(
225+
report.longrepr, "reprcrash", None
226+
) # type: Optional[ReprFileLocation]
227+
if reprcrash is not None:
228+
reason = reprcrash.message
222229
else:
223230
reason = str(report.longrepr)
224231

@@ -237,7 +244,7 @@ def append_skipped(self, report: TestReport) -> None:
237244
skipped = ET.Element("skipped", type="pytest.xfail", message=xfailreason)
238245
self.append(skipped)
239246
else:
240-
assert report.longrepr is not None
247+
assert isinstance(report.longrepr, tuple)
241248
filename, lineno, skipreason = report.longrepr
242249
if skipreason.startswith("Skipped: "):
243250
skipreason = skipreason[9:]

src/_pytest/pytester.py

Lines changed: 86 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from _pytest import timing
2929
from _pytest._code import Source
3030
from _pytest.capture import _get_multicapture
31+
from _pytest.compat import overload
3132
from _pytest.compat import TYPE_CHECKING
3233
from _pytest.config import _PluggyPlugin
3334
from _pytest.config import Config
@@ -42,11 +43,13 @@
4243
from _pytest.pathlib import make_numbered_dir
4344
from _pytest.pathlib import Path
4445
from _pytest.python import Module
46+
from _pytest.reports import CollectReport
4547
from _pytest.reports import TestReport
4648
from _pytest.tmpdir import TempdirFactory
4749

4850
if TYPE_CHECKING:
4951
from typing import Type
52+
from typing_extensions import Literal
5053

5154
import pexpect
5255

@@ -180,24 +183,24 @@ def gethookrecorder(self, hook) -> "HookRecorder":
180183
return hookrecorder
181184

182185

183-
def get_public_names(values):
186+
def get_public_names(values: Iterable[str]) -> List[str]:
184187
"""Only return names from iterator values without a leading underscore."""
185188
return [x for x in values if x[0] != "_"]
186189

187190

188191
class ParsedCall:
189-
def __init__(self, name, kwargs):
192+
def __init__(self, name: str, kwargs) -> None:
190193
self.__dict__.update(kwargs)
191194
self._name = name
192195

193-
def __repr__(self):
196+
def __repr__(self) -> str:
194197
d = self.__dict__.copy()
195198
del d["_name"]
196199
return "<ParsedCall {!r}(**{!r})>".format(self._name, d)
197200

198201
if TYPE_CHECKING:
199202
# The class has undetermined attributes, this tells mypy about it.
200-
def __getattr__(self, key):
203+
def __getattr__(self, key: str):
201204
raise NotImplementedError()
202205

203206

@@ -211,6 +214,7 @@ class HookRecorder:
211214
def __init__(self, pluginmanager: PytestPluginManager) -> None:
212215
self._pluginmanager = pluginmanager
213216
self.calls = [] # type: List[ParsedCall]
217+
self.ret = None # type: Optional[Union[int, ExitCode]]
214218

215219
def before(hook_name: str, hook_impls, kwargs) -> None:
216220
self.calls.append(ParsedCall(hook_name, kwargs))
@@ -228,7 +232,7 @@ def getcalls(self, names: Union[str, Iterable[str]]) -> List[ParsedCall]:
228232
names = names.split()
229233
return [call for call in self.calls if call._name in names]
230234

231-
def assert_contains(self, entries) -> None:
235+
def assert_contains(self, entries: Sequence[Tuple[str, str]]) -> None:
232236
__tracebackhide__ = True
233237
i = 0
234238
entries = list(entries)
@@ -266,22 +270,46 @@ def getcall(self, name: str) -> ParsedCall:
266270

267271
# functionality for test reports
268272

273+
@overload
269274
def getreports(
275+
self, names: "Literal['pytest_collectreport']",
276+
) -> Sequence[CollectReport]:
277+
raise NotImplementedError()
278+
279+
@overload # noqa: F811
280+
def getreports( # noqa: F811
281+
self, names: "Literal['pytest_runtest_logreport']",
282+
) -> Sequence[TestReport]:
283+
raise NotImplementedError()
284+
285+
@overload # noqa: F811
286+
def getreports( # noqa: F811
287+
self,
288+
names: Union[str, Iterable[str]] = (
289+
"pytest_collectreport",
290+
"pytest_runtest_logreport",
291+
),
292+
) -> Sequence[Union[CollectReport, TestReport]]:
293+
raise NotImplementedError()
294+
295+
def getreports( # noqa: F811
270296
self,
271-
names: Union[
272-
str, Iterable[str]
273-
] = "pytest_runtest_logreport pytest_collectreport",
274-
) -> List[TestReport]:
297+
names: Union[str, Iterable[str]] = (
298+
"pytest_collectreport",
299+
"pytest_runtest_logreport",
300+
),
301+
) -> Sequence[Union[CollectReport, TestReport]]:
275302
return [x.report for x in self.getcalls(names)]
276303

277304
def matchreport(
278305
self,
279306
inamepart: str = "",
280-
names: Union[
281-
str, Iterable[str]
282-
] = "pytest_runtest_logreport pytest_collectreport",
283-
when=None,
284-
):
307+
names: Union[str, Iterable[str]] = (
308+
"pytest_runtest_logreport",
309+
"pytest_collectreport",
310+
),
311+
when: Optional[str] = None,
312+
) -> Union[CollectReport, TestReport]:
285313
"""Return a testreport whose dotted import path matches."""
286314
values = []
287315
for rep in self.getreports(names=names):
@@ -305,26 +333,56 @@ def matchreport(
305333
)
306334
return values[0]
307335

336+
@overload
308337
def getfailures(
338+
self, names: "Literal['pytest_collectreport']",
339+
) -> Sequence[CollectReport]:
340+
raise NotImplementedError()
341+
342+
@overload # noqa: F811
343+
def getfailures( # noqa: F811
344+
self, names: "Literal['pytest_runtest_logreport']",
345+
) -> Sequence[TestReport]:
346+
raise NotImplementedError()
347+
348+
@overload # noqa: F811
349+
def getfailures( # noqa: F811
309350
self,
310-
names: Union[
311-
str, Iterable[str]
312-
] = "pytest_runtest_logreport pytest_collectreport",
313-
) -> List[TestReport]:
351+
names: Union[str, Iterable[str]] = (
352+
"pytest_collectreport",
353+
"pytest_runtest_logreport",
354+
),
355+
) -> Sequence[Union[CollectReport, TestReport]]:
356+
raise NotImplementedError()
357+
358+
def getfailures( # noqa: F811
359+
self,
360+
names: Union[str, Iterable[str]] = (
361+
"pytest_collectreport",
362+
"pytest_runtest_logreport",
363+
),
364+
) -> Sequence[Union[CollectReport, TestReport]]:
314365
return [rep for rep in self.getreports(names) if rep.failed]
315366

316-
def getfailedcollections(self) -> List[TestReport]:
367+
def getfailedcollections(self) -> Sequence[CollectReport]:
317368
return self.getfailures("pytest_collectreport")
318369

319370
def listoutcomes(
320371
self,
321-
) -> Tuple[List[TestReport], List[TestReport], List[TestReport]]:
372+
) -> Tuple[
373+
Sequence[TestReport],
374+
Sequence[Union[CollectReport, TestReport]],
375+
Sequence[Union[CollectReport, TestReport]],
376+
]:
322377
passed = []
323378
skipped = []
324379
failed = []
325-
for rep in self.getreports("pytest_collectreport pytest_runtest_logreport"):
380+
for rep in self.getreports(
381+
("pytest_collectreport", "pytest_runtest_logreport")
382+
):
326383
if rep.passed:
327384
if rep.when == "call":
385+
assert isinstance(rep, TestReport)
328386
passed.append(rep)
329387
elif rep.skipped:
330388
skipped.append(rep)
@@ -879,7 +937,7 @@ def runitem(self, source):
879937
runner = testclassinstance.getrunner()
880938
return runner(item)
881939

882-
def inline_runsource(self, source, *cmdlineargs):
940+
def inline_runsource(self, source, *cmdlineargs) -> HookRecorder:
883941
"""Run a test module in process using ``pytest.main()``.
884942
885943
This run writes "source" into a temporary file and runs
@@ -896,7 +954,7 @@ def inline_runsource(self, source, *cmdlineargs):
896954
values = list(cmdlineargs) + [p]
897955
return self.inline_run(*values)
898956

899-
def inline_genitems(self, *args):
957+
def inline_genitems(self, *args) -> Tuple[List[Item], HookRecorder]:
900958
"""Run ``pytest.main(['--collectonly'])`` in-process.
901959
902960
Runs the :py:func:`pytest.main` function to run all of pytest inside
@@ -907,7 +965,9 @@ def inline_genitems(self, *args):
907965
items = [x.item for x in rec.getcalls("pytest_itemcollected")]
908966
return items, rec
909967

910-
def inline_run(self, *args, plugins=(), no_reraise_ctrlc: bool = False):
968+
def inline_run(
969+
self, *args, plugins=(), no_reraise_ctrlc: bool = False
970+
) -> HookRecorder:
911971
"""Run ``pytest.main()`` in-process, returning a HookRecorder.
912972
913973
Runs the :py:func:`pytest.main` function to run all of pytest inside
@@ -962,7 +1022,7 @@ def pytest_configure(x, config: Config) -> None:
9621022
class reprec: # type: ignore
9631023
pass
9641024

965-
reprec.ret = ret # type: ignore[attr-defined]
1025+
reprec.ret = ret
9661026

9671027
# Typically we reraise keyboard interrupts from the child run
9681028
# because it's our user requesting interruption of the testing.
@@ -1010,6 +1070,7 @@ class reprec: # type: ignore
10101070
sys.stdout.write(out)
10111071
sys.stderr.write(err)
10121072

1073+
assert reprec.ret is not None
10131074
res = RunResult(
10141075
reprec.ret, out.splitlines(), err.splitlines(), timing.time() - now
10151076
)

src/_pytest/reports.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from io import StringIO
22
from pprint import pprint
33
from typing import Any
4+
from typing import cast
45
from typing import Dict
56
from typing import Iterable
67
from typing import Iterator
@@ -15,6 +16,7 @@
1516

1617
from _pytest._code.code import ExceptionChainRepr
1718
from _pytest._code.code import ExceptionInfo
19+
from _pytest._code.code import ExceptionRepr
1820
from _pytest._code.code import ReprEntry
1921
from _pytest._code.code import ReprEntryNative
2022
from _pytest._code.code import ReprExceptionInfo
@@ -57,8 +59,9 @@ def getworkerinfoline(node):
5759
class BaseReport:
5860
when = None # type: Optional[str]
5961
location = None # type: Optional[Tuple[str, Optional[int], str]]
60-
# TODO: Improve this Any.
61-
longrepr = None # type: Optional[Any]
62+
longrepr = (
63+
None
64+
) # type: Union[None, ExceptionInfo[BaseException], Tuple[str, int, str], str, TerminalRepr]
6265
sections = [] # type: List[Tuple[str, str]]
6366
nodeid = None # type: str
6467

@@ -79,7 +82,8 @@ def toterminal(self, out: TerminalWriter) -> None:
7982
return
8083

8184
if hasattr(longrepr, "toterminal"):
82-
longrepr.toterminal(out)
85+
longrepr_terminal = cast(TerminalRepr, longrepr)
86+
longrepr_terminal.toterminal(out)
8387
else:
8488
try:
8589
s = str(longrepr)
@@ -233,7 +237,9 @@ def __init__(
233237
location: Tuple[str, Optional[int], str],
234238
keywords,
235239
outcome: "Literal['passed', 'failed', 'skipped']",
236-
longrepr,
240+
longrepr: Union[
241+
None, ExceptionInfo[BaseException], Tuple[str, int, str], str, TerminalRepr
242+
],
237243
when: "Literal['setup', 'call', 'teardown']",
238244
sections: Iterable[Tuple[str, str]] = (),
239245
duration: float = 0,
@@ -293,8 +299,9 @@ def from_item_and_call(cls, item: Item, call: "CallInfo[None]") -> "TestReport":
293299
sections = []
294300
if not call.excinfo:
295301
outcome = "passed" # type: Literal["passed", "failed", "skipped"]
296-
# TODO: Improve this Any.
297-
longrepr = None # type: Optional[Any]
302+
longrepr = (
303+
None
304+
) # type: Union[None, ExceptionInfo[BaseException], Tuple[str, int, str], str, TerminalRepr]
298305
else:
299306
if not isinstance(excinfo, ExceptionInfo):
300307
outcome = "failed"
@@ -372,7 +379,7 @@ def __repr__(self) -> str:
372379

373380

374381
class CollectErrorRepr(TerminalRepr):
375-
def __init__(self, msg) -> None:
382+
def __init__(self, msg: str) -> None:
376383
self.longrepr = msg
377384

378385
def toterminal(self, out: TerminalWriter) -> None:
@@ -436,16 +443,18 @@ def serialize_repr_crash(
436443
else:
437444
return None
438445

439-
def serialize_longrepr(rep: BaseReport) -> Dict[str, Any]:
446+
def serialize_exception_longrepr(rep: BaseReport) -> Dict[str, Any]:
440447
assert rep.longrepr is not None
448+
# TODO: Investigate whether the duck typing is really necessary here.
449+
longrepr = cast(ExceptionRepr, rep.longrepr)
441450
result = {
442-
"reprcrash": serialize_repr_crash(rep.longrepr.reprcrash),
443-
"reprtraceback": serialize_repr_traceback(rep.longrepr.reprtraceback),
444-
"sections": rep.longrepr.sections,
451+
"reprcrash": serialize_repr_crash(longrepr.reprcrash),
452+
"reprtraceback": serialize_repr_traceback(longrepr.reprtraceback),
453+
"sections": longrepr.sections,
445454
} # type: Dict[str, Any]
446-
if isinstance(rep.longrepr, ExceptionChainRepr):
455+
if isinstance(longrepr, ExceptionChainRepr):
447456
result["chain"] = []
448-
for repr_traceback, repr_crash, description in rep.longrepr.chain:
457+
for repr_traceback, repr_crash, description in longrepr.chain:
449458
result["chain"].append(
450459
(
451460
serialize_repr_traceback(repr_traceback),
@@ -462,7 +471,7 @@ def serialize_longrepr(rep: BaseReport) -> Dict[str, Any]:
462471
if hasattr(report.longrepr, "reprtraceback") and hasattr(
463472
report.longrepr, "reprcrash"
464473
):
465-
d["longrepr"] = serialize_longrepr(report)
474+
d["longrepr"] = serialize_exception_longrepr(report)
466475
else:
467476
d["longrepr"] = str(report.longrepr)
468477
else:

0 commit comments

Comments
 (0)