28
28
from _pytest import timing
29
29
from _pytest ._code import Source
30
30
from _pytest .capture import _get_multicapture
31
+ from _pytest .compat import overload
31
32
from _pytest .compat import TYPE_CHECKING
32
33
from _pytest .config import _PluggyPlugin
33
34
from _pytest .config import Config
42
43
from _pytest .pathlib import make_numbered_dir
43
44
from _pytest .pathlib import Path
44
45
from _pytest .python import Module
46
+ from _pytest .reports import CollectReport
45
47
from _pytest .reports import TestReport
46
48
from _pytest .tmpdir import TempdirFactory
47
49
48
50
if TYPE_CHECKING :
49
51
from typing import Type
52
+ from typing_extensions import Literal
50
53
51
54
import pexpect
52
55
@@ -180,24 +183,24 @@ def gethookrecorder(self, hook) -> "HookRecorder":
180
183
return hookrecorder
181
184
182
185
183
- def get_public_names (values ) :
186
+ def get_public_names (values : Iterable [ str ]) -> List [ str ] :
184
187
"""Only return names from iterator values without a leading underscore."""
185
188
return [x for x in values if x [0 ] != "_" ]
186
189
187
190
188
191
class ParsedCall :
189
- def __init__ (self , name , kwargs ):
192
+ def __init__ (self , name : str , kwargs ) -> None :
190
193
self .__dict__ .update (kwargs )
191
194
self ._name = name
192
195
193
- def __repr__ (self ):
196
+ def __repr__ (self ) -> str :
194
197
d = self .__dict__ .copy ()
195
198
del d ["_name" ]
196
199
return "<ParsedCall {!r}(**{!r})>" .format (self ._name , d )
197
200
198
201
if TYPE_CHECKING :
199
202
# The class has undetermined attributes, this tells mypy about it.
200
- def __getattr__ (self , key ):
203
+ def __getattr__ (self , key : str ):
201
204
raise NotImplementedError ()
202
205
203
206
@@ -211,6 +214,7 @@ class HookRecorder:
211
214
def __init__ (self , pluginmanager : PytestPluginManager ) -> None :
212
215
self ._pluginmanager = pluginmanager
213
216
self .calls = [] # type: List[ParsedCall]
217
+ self .ret = None # type: Optional[Union[int, ExitCode]]
214
218
215
219
def before (hook_name : str , hook_impls , kwargs ) -> None :
216
220
self .calls .append (ParsedCall (hook_name , kwargs ))
@@ -228,7 +232,7 @@ def getcalls(self, names: Union[str, Iterable[str]]) -> List[ParsedCall]:
228
232
names = names .split ()
229
233
return [call for call in self .calls if call ._name in names ]
230
234
231
- def assert_contains (self , entries ) -> None :
235
+ def assert_contains (self , entries : Sequence [ Tuple [ str , str ]] ) -> None :
232
236
__tracebackhide__ = True
233
237
i = 0
234
238
entries = list (entries )
@@ -266,22 +270,46 @@ def getcall(self, name: str) -> ParsedCall:
266
270
267
271
# functionality for test reports
268
272
273
+ @overload
269
274
def getreports (
275
+ self , names : "Literal['pytest_collectreport']" ,
276
+ ) -> Sequence [CollectReport ]:
277
+ raise NotImplementedError ()
278
+
279
+ @overload # noqa: F811
280
+ def getreports ( # noqa: F811
281
+ self , names : "Literal['pytest_runtest_logreport']" ,
282
+ ) -> Sequence [TestReport ]:
283
+ raise NotImplementedError ()
284
+
285
+ @overload # noqa: F811
286
+ def getreports ( # noqa: F811
287
+ self ,
288
+ names : Union [str , Iterable [str ]] = (
289
+ "pytest_collectreport" ,
290
+ "pytest_runtest_logreport" ,
291
+ ),
292
+ ) -> Sequence [Union [CollectReport , TestReport ]]:
293
+ raise NotImplementedError ()
294
+
295
+ def getreports ( # noqa: F811
270
296
self ,
271
- names : Union [
272
- str , Iterable [str ]
273
- ] = "pytest_runtest_logreport pytest_collectreport" ,
274
- ) -> List [TestReport ]:
297
+ names : Union [str , Iterable [str ]] = (
298
+ "pytest_collectreport" ,
299
+ "pytest_runtest_logreport" ,
300
+ ),
301
+ ) -> Sequence [Union [CollectReport , TestReport ]]:
275
302
return [x .report for x in self .getcalls (names )]
276
303
277
304
def matchreport (
278
305
self ,
279
306
inamepart : str = "" ,
280
- names : Union [
281
- str , Iterable [str ]
282
- ] = "pytest_runtest_logreport pytest_collectreport" ,
283
- when = None ,
284
- ):
307
+ names : Union [str , Iterable [str ]] = (
308
+ "pytest_runtest_logreport" ,
309
+ "pytest_collectreport" ,
310
+ ),
311
+ when : Optional [str ] = None ,
312
+ ) -> Union [CollectReport , TestReport ]:
285
313
"""Return a testreport whose dotted import path matches."""
286
314
values = []
287
315
for rep in self .getreports (names = names ):
@@ -305,26 +333,56 @@ def matchreport(
305
333
)
306
334
return values [0 ]
307
335
336
+ @overload
308
337
def getfailures (
338
+ self , names : "Literal['pytest_collectreport']" ,
339
+ ) -> Sequence [CollectReport ]:
340
+ raise NotImplementedError ()
341
+
342
+ @overload # noqa: F811
343
+ def getfailures ( # noqa: F811
344
+ self , names : "Literal['pytest_runtest_logreport']" ,
345
+ ) -> Sequence [TestReport ]:
346
+ raise NotImplementedError ()
347
+
348
+ @overload # noqa: F811
349
+ def getfailures ( # noqa: F811
309
350
self ,
310
- names : Union [
311
- str , Iterable [str ]
312
- ] = "pytest_runtest_logreport pytest_collectreport" ,
313
- ) -> List [TestReport ]:
351
+ names : Union [str , Iterable [str ]] = (
352
+ "pytest_collectreport" ,
353
+ "pytest_runtest_logreport" ,
354
+ ),
355
+ ) -> Sequence [Union [CollectReport , TestReport ]]:
356
+ raise NotImplementedError ()
357
+
358
+ def getfailures ( # noqa: F811
359
+ self ,
360
+ names : Union [str , Iterable [str ]] = (
361
+ "pytest_collectreport" ,
362
+ "pytest_runtest_logreport" ,
363
+ ),
364
+ ) -> Sequence [Union [CollectReport , TestReport ]]:
314
365
return [rep for rep in self .getreports (names ) if rep .failed ]
315
366
316
- def getfailedcollections (self ) -> List [ TestReport ]:
367
+ def getfailedcollections (self ) -> Sequence [ CollectReport ]:
317
368
return self .getfailures ("pytest_collectreport" )
318
369
319
370
def listoutcomes (
320
371
self ,
321
- ) -> Tuple [List [TestReport ], List [TestReport ], List [TestReport ]]:
372
+ ) -> Tuple [
373
+ Sequence [TestReport ],
374
+ Sequence [Union [CollectReport , TestReport ]],
375
+ Sequence [Union [CollectReport , TestReport ]],
376
+ ]:
322
377
passed = []
323
378
skipped = []
324
379
failed = []
325
- for rep in self .getreports ("pytest_collectreport pytest_runtest_logreport" ):
380
+ for rep in self .getreports (
381
+ ("pytest_collectreport" , "pytest_runtest_logreport" )
382
+ ):
326
383
if rep .passed :
327
384
if rep .when == "call" :
385
+ assert isinstance (rep , TestReport )
328
386
passed .append (rep )
329
387
elif rep .skipped :
330
388
skipped .append (rep )
@@ -879,7 +937,7 @@ def runitem(self, source):
879
937
runner = testclassinstance .getrunner ()
880
938
return runner (item )
881
939
882
- def inline_runsource (self , source , * cmdlineargs ):
940
+ def inline_runsource (self , source , * cmdlineargs ) -> HookRecorder :
883
941
"""Run a test module in process using ``pytest.main()``.
884
942
885
943
This run writes "source" into a temporary file and runs
@@ -896,7 +954,7 @@ def inline_runsource(self, source, *cmdlineargs):
896
954
values = list (cmdlineargs ) + [p ]
897
955
return self .inline_run (* values )
898
956
899
- def inline_genitems (self , * args ):
957
+ def inline_genitems (self , * args ) -> Tuple [ List [ Item ], HookRecorder ] :
900
958
"""Run ``pytest.main(['--collectonly'])`` in-process.
901
959
902
960
Runs the :py:func:`pytest.main` function to run all of pytest inside
@@ -907,7 +965,9 @@ def inline_genitems(self, *args):
907
965
items = [x .item for x in rec .getcalls ("pytest_itemcollected" )]
908
966
return items , rec
909
967
910
- def inline_run (self , * args , plugins = (), no_reraise_ctrlc : bool = False ):
968
+ def inline_run (
969
+ self , * args , plugins = (), no_reraise_ctrlc : bool = False
970
+ ) -> HookRecorder :
911
971
"""Run ``pytest.main()`` in-process, returning a HookRecorder.
912
972
913
973
Runs the :py:func:`pytest.main` function to run all of pytest inside
@@ -962,7 +1022,7 @@ def pytest_configure(x, config: Config) -> None:
962
1022
class reprec : # type: ignore
963
1023
pass
964
1024
965
- reprec .ret = ret # type: ignore[attr-defined]
1025
+ reprec .ret = ret
966
1026
967
1027
# Typically we reraise keyboard interrupts from the child run
968
1028
# because it's our user requesting interruption of the testing.
@@ -1010,6 +1070,7 @@ class reprec: # type: ignore
1010
1070
sys .stdout .write (out )
1011
1071
sys .stderr .write (err )
1012
1072
1073
+ assert reprec .ret is not None
1013
1074
res = RunResult (
1014
1075
reprec .ret , out .splitlines (), err .splitlines (), timing .time () - now
1015
1076
)
0 commit comments